validate.py 2.18 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import argparse
import sys

from lm_eval._cli.base import SubCommand


class ValidateCommand(SubCommand):
    """Command for validating tasks.

    Registers a ``validate`` subparser and, when run, asks
    ``lm_eval.tasks.TaskManager`` to match the requested task names, exiting
    with status 1 if any requested name cannot be matched.
    """

    def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
        """Attach the ``validate`` subparser to ``subparsers``.

        Args:
            subparsers: The parent parser's subparsers action.
            *args: Forwarded unchanged to ``SubCommand.__init__``.
            **kwargs: Forwarded unchanged to ``SubCommand.__init__``.
        """
        super().__init__(*args, **kwargs)
        # Create and configure the parser
        parser = subparsers.add_parser(
            "validate",
            help="Validate task configurations",
            description="Validate task configurations and check for errors.",
            epilog="""
Examples:
  lm-eval validate --tasks hellaswag              # Validate single task
  lm-eval validate --tasks arc_easy,arc_challenge # Validate multiple tasks
  lm-eval validate --tasks mmlu --include_path ./custom_tasks
            """,
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        # Add command-specific arguments
        self._add_args(parser)

        # Set the function to execute for this subcommand
        parser.set_defaults(func=self.execute)

    def _add_args(self, parser: argparse.ArgumentParser) -> None:
        """Register the ``--tasks`` and ``--include_path`` options on ``parser``."""
        parser.add_argument(
            "--tasks",
            "-t",
            required=True,
            type=str,
            metavar="task1,task2",
            help="Comma-separated list of task names to validate",
        )
        parser.add_argument(
            "--include_path",
            type=str,
            default=None,
            metavar="DIR",
            help="Additional path to include if there are external tasks.",
        )

    @staticmethod
    def _task_found(task: str, matched: list) -> bool:
        """Return True if the requested entry ``task`` is covered by ``matched``."""
        import fnmatch

        if task in matched:
            return True
        # match_tasks presumably expands glob patterns (e.g. "arc_*") into
        # concrete task names, so a pattern never appears verbatim in its
        # result. Treat a pattern as found if it matched at least one name.
        # NOTE(review): assumes matched names are strings — confirm.
        return "*" in task and any(fnmatch.fnmatch(name, task) for name in matched)

    def execute(self, args: argparse.Namespace) -> None:
        """Execute the validate command.

        Prints the parsed task list, then either a success message or the
        entries that could not be matched. Calls ``sys.exit(1)`` when no task
        names were given or when any entry is missing.

        Args:
            args: Parsed namespace providing ``tasks`` (comma-separated str)
                and ``include_path`` (str or None).
        """
        from lm_eval.tasks import TaskManager

        # Strip surrounding whitespace and drop empty entries so inputs like
        # "arc_easy, arc_challenge" or "hellaswag," don't produce bogus
        # "not found" reports for " arc_challenge" or "".
        task_list = [name.strip() for name in args.tasks.split(",") if name.strip()]
        if not task_list:
            print("No task names provided")
            sys.exit(1)

        task_manager = TaskManager(include_path=args.include_path)

        print(f"Validating tasks: {task_list}")
        # For now, just validate that tasks exist
        task_names = task_manager.match_tasks(task_list)
        task_missing = [
            task for task in task_list if not self._task_found(task, task_names)
        ]

        if task_missing:
            missing = ", ".join(task_missing)
            print(f"Tasks not found: {missing}")
            sys.exit(1)

        print("All tasks found and valid")