list.py 2.14 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse

from lm_eval._cli.base import SubCommand


class ListCommand(SubCommand):
    """CLI subcommand that prints available tasks, groups, subtasks, or tags."""

    def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
        """Register the ``list`` subcommand on ``subparsers``.

        Args:
            subparsers: Sub-parser action that receives the new ``list`` parser.
            *args: Positional arguments forwarded unchanged to ``SubCommand``.
            **kwargs: Keyword arguments forwarded unchanged to ``SubCommand``.
        """
        super().__init__(*args, **kwargs)

        # RawDescriptionHelpFormatter keeps the hand-aligned epilog intact
        # instead of re-wrapping it.
        list_parser = subparsers.add_parser(
            "list",
            help="List available tasks, groups, subtasks, or tags",
            description="List available tasks, groups, subtasks, or tags from the evaluation harness.",
            epilog="""
Examples:
  lm-eval list tasks         # List all available tasks
  lm-eval list groups        # List task groups only
  lm-eval list subtasks      # List subtasks only
  lm-eval list tags          # List available tags
            """,
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        # Attach the arguments that are specific to this subcommand.
        self._add_args(list_parser)

        # Expose the handler as ``func`` on the parsed namespace.
        list_parser.set_defaults(func=self.execute)

    def _add_args(self, parser: argparse.ArgumentParser) -> None:
        """Declare the ``what`` positional and the ``--include_path`` option.

        Args:
            parser: Parser of the ``list`` subcommand to extend in place.
        """
        # Required selector; argparse rejects anything outside ``choices``.
        parser.add_argument(
            "what",
            choices=["tasks", "groups", "subtasks", "tags"],
            help="What to list: tasks (all), groups, subtasks, or tags",
        )
        # Optional extra directory handed to the task manager on execution.
        parser.add_argument(
            "--include_path",
            type=str,
            default=None,
            metavar="DIR",
            help="Additional path to include if there are external tasks.",
        )

    def execute(self, args: argparse.Namespace) -> None:
        """Print the listing selected by ``args.what``.

        Args:
            args: Parsed namespace; ``include_path`` and ``what`` are read.
        """
        # Imported inside the method rather than at module level
        # (presumably to keep CLI start-up light -- confirm before moving).
        from lm_eval.tasks import TaskManager

        task_manager = TaskManager(include_path=args.include_path)

        # Keyword overrides passed to ``list_all_tasks`` for each selector;
        # "tasks" relies entirely on the method's own defaults.
        listing_options = {
            "tasks": {},
            "groups": {"list_subtasks": False, "list_tags": False},
            "subtasks": {"list_groups": False, "list_tags": False},
            "tags": {"list_groups": False, "list_subtasks": False},
        }

        selected = listing_options.get(args.what)
        if selected is None:
            # Unrecognised selector: print nothing.
            return

        print(task_manager.list_all_tasks(**selected))