import argparse
import json
import logging
import os
import textwrap
from functools import partial

from lm_eval._cli.subcommand import SubCommand
from lm_eval._cli.utils import (
    _int_or_none_list_arg_type,
    request_caching_arg_to_dict,
    try_parse_json,
)


class Run(SubCommand):
    """Command for running language model evaluation."""

    def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parser = subparsers.add_parser(
            "run",
            help="Run the evaluation harness on specified tasks",
            description="Evaluate language models on various benchmarks and tasks.",
            usage="lm-eval run --model <model> --tasks <task1,task2,...> [options]",
            epilog=textwrap.dedent("""
                examples:
                  # Basic evaluation with HuggingFace model
                  $ lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag

                  # Evaluate on multiple tasks with few-shot examples
                  $ lm-eval run --model vllm --model_args pretrained=EleutherAI/gpt-j-6B --tasks arc_easy,arc_challenge --num_fewshot 5

                  # Evaluation with custom generation parameters
                  $ lm-eval run --model hf --model_args pretrained=gpt2 --tasks lambada --gen_kwargs "temperature=0.8,top_p=0.95"

                  # Use configuration file
                  $ lm-eval run --config my_config.yaml --tasks mmlu

                For more information, see: https://github.com/EleutherAI/lm-evaluation-harness
            """),
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
        self._add_args()
        self._parser.set_defaults(func=self._execute)

    def _add_args(self) -> None:
        # Defaults are set in config/evaluate_config.py
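        # Options using default=argparse.SUPPRESS are left out of the parsed
        # namespace when not supplied, so they cannot clobber values coming from
        # a --config YAML file (presumed rationale for mixing SUPPRESS and None).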
        config_group = self._parser.add_argument_group("configuration")
        config_group.add_argument(
            "--config",
            "-C",
            default=None,
            type=str,
            metavar="YAML_PATH",
            help="Set initial arguments from YAML config",
        )

        # Model and Tasks
        model_group = self._parser.add_argument_group("model and tasks")
        model_group.add_argument(
            "--model",
            "-m",
            type=str,
            default=None,
            metavar="MODEL_NAME",
            help="Model name (default: hf)",
        )
        model_group.add_argument(
            "--tasks",
            "-t",
            default=None,
            type=str,
            metavar="TASK1,TASK2",
            help=textwrap.dedent("""
                Comma-separated list of task names or groupings.
                Use 'lm-eval list tasks' to see all available tasks.
            """).strip(),
        )
        model_group.add_argument(
            "--model_args",
            "-a",
            default=None,
            type=try_parse_json,
            metavar="ARGS",
            help="Model arguments as 'key=val,key2=val2' or JSON string",
        )

        # Evaluation Settings
        eval_group = self._parser.add_argument_group("evaluation settings")
        eval_group.add_argument(
            "--num_fewshot",
            "-f",
            type=int,
            default=None,
            metavar="N",
            help="Number of examples in few-shot context",
        )
        eval_group.add_argument(
            "--batch_size",
            "-b",
            type=str,
            default=argparse.SUPPRESS,
            metavar="auto|auto:N|N",
            help=textwrap.dedent(
                "Batch size: 'auto', 'auto:N' (auto-tune N times), or integer (default: 1)"
            ),
        )
        eval_group.add_argument(
            "--max_batch_size",
            type=int,
            default=None,
            metavar="N",
            help="Maximum batch size when using --batch_size auto",
        )
        eval_group.add_argument(
            "--device",
            type=str,
            default=None,
            metavar="DEVICE",
            help="Device to use (e.g. cuda, cuda:0, cpu, mps)",
        )
        eval_group.add_argument(
            "--gen_kwargs",
            type=try_parse_json,
            default=None,
            metavar="KWARGS",
            help="Generation arguments as 'key=val,key2=val2' or JSON string",
        )

        # Data and Output
        data_group = self._parser.add_argument_group("data and output")
        data_group.add_argument(
            "--output_path",
            "-o",
            default=None,
            type=str,
            metavar="OUTPUT_PATH",
            help="Output dir or json file for results (and samples)",
        )
        data_group.add_argument(
            "--log_samples",
            "-s",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Save all model outputs and documents for post-hoc analysis",
        )
        data_group.add_argument(
            "--limit",
            "-L",
            type=float,
            default=None,
            metavar="N|0.0-1.0",
            help="Limit examples per task (integer count or fraction)",
        )
        data_group.add_argument(
            "--samples",
            "-E",
            default=None,
            type=try_parse_json,
            metavar="JSON_FILE",
            help=textwrap.dedent(
                'JSON file with specific sample indices for inputs: {"task_name":[indices],...}. Incompatible with --limit.'
            ),
        )

        # Caching and Performance
        cache_group = self._parser.add_argument_group("caching and performance")
        cache_group.add_argument(
            "--use_cache",
            "-c",
            type=str,
            default=None,
            metavar="CACHE_DIR",
            help="SQLite database path for caching model outputs.",
        )
        cache_group.add_argument(
            "--cache_requests",
            type=request_caching_arg_to_dict,
            default=None,
            # argparse applies `type` before validating `choices`, so the dict
            # returned by request_caching_arg_to_dict would never match string
            # choices; document the allowed values via metavar instead
            metavar="true|refresh|delete",
            help="Cache dataset request building (true|refresh|delete)",
        )
        cache_group.add_argument(
            "--check_integrity",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Run task test suite validation",
        )

        # Prompt Formatting
        template_group = self._parser.add_argument_group("instruct formatting")
        template_group.add_argument(
            "--system_instruction",
            type=str,
            default=None,
            metavar="INSTRUCTION",
            help="Add custom system instruction.",
        )
        template_group.add_argument(
            "--apply_chat_template",
            type=str,
            nargs="?",
            const=True,
            default=argparse.SUPPRESS,
            metavar="TEMPLATE",
            help="Apply chat template to prompts (optional template name)",
        )
        template_group.add_argument(
            "--fewshot_as_multiturn",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Use fewshot examples as multi-turn conversation",
        )

        # Task Management
        task_group = self._parser.add_argument_group("task management")
        task_group.add_argument(
            "--include_path",
            type=str,
            default=None,
            metavar="TASK_DIR",
            help="Additional directory for external tasks",
        )

        # Logging and Tracking
        logging_group = self._parser.add_argument_group("logging and tracking")
        logging_group.add_argument(
            "--verbosity",
            "-v",
            type=str.upper,
            default=None,
            metavar="LEVEL",
            help="(Deprecated) Log level. Use LOGLEVEL env var instead",
        )
        logging_group.add_argument(
            "--write_out",
            "-w",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Print prompts for first few documents",
        )
        logging_group.add_argument(
            "--show_config",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Display full task configuration after evaluation",
        )
        logging_group.add_argument(
            "--wandb_args",
            type=str,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Weights & Biases init arguments (key=val,key2=val2)",
        )
        logging_group.add_argument(
            "--wandb_config_args",
            type=str,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Weights & Biases config arguments (key=val,key2=val2)",
        )
        logging_group.add_argument(
            "--hf_hub_log_args",
            type=str,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Hugging Face Hub logging arguments (key=val,key2=val2)",
        )

        # Advanced Options
        advanced_group = self._parser.add_argument_group("advanced options")
        advanced_group.add_argument(
            "--predict_only",
            "-x",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Save predictions only, skip metric computation",
        )
        default_seed_string = "0,1234,1234,1234"
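        # Accepts a single integer applied to all four seeds, or a comma-separated
        # list of four values in the order python,numpy,torch,fewshot; 'None' skips
        # seeding that component. E.g. --seed 42 -> 42,42,42,42;
        # --seed 0,None,8,52 -> numpy seeding skipped.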
        advanced_group.add_argument(
            "--seed",
            type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
            default=None,
            metavar="SEED|S1,S2,S3,S4",
            help=textwrap.dedent(f"""
                Random seeds for python,numpy,torch,fewshot (default: {default_seed_string}).
                Use single integer for all, or comma-separated list of 4 values.
                Use 'None' to skip setting a seed. Example: --seed 42 or --seed 0,None,8,52
            """).strip(),
        )
        advanced_group.add_argument(
            "--trust_remote_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Allow executing remote code from Hugging Face Hub",
        )
        advanced_group.add_argument(
            "--confirm_run_unsafe_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Confirm understanding of unsafe code execution risks",
        )
        advanced_group.add_argument(
            "--metadata",
            type=json.loads,
            default=None,
            metavar="JSON",
            help=textwrap.dedent(
                """JSON metadata for task configs (merged with model_args), required for some tasks such as RULER"""
            ),
        )

    def _execute(self, args: argparse.Namespace) -> None:
        """Runs the evaluation harness with the provided arguments."""
        # Disable HF tokenizers' fork parallelism to avoid fork-related
        # warnings/deadlocks in multiprocess runs
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
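        # lm_eval imports are deferred into _execute, presumably to keep CLI
        # startup (e.g. `lm-eval --help`) fast.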
        from lm_eval.config.evaluate_config import EvaluatorConfig

        eval_logger = logging.getLogger(__name__)

        # Create and validate config (most validation now occurs in EvaluatorConfig)
        cfg = EvaluatorConfig.from_cli(args)

        from lm_eval import simple_evaluate
        from lm_eval.loggers import EvaluationTracker, WandbLogger
        from lm_eval.utils import handle_non_serializable, make_table

        # Set up logging
        if cfg.wandb_args:
            wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)

        # Set up evaluation tracker
        if cfg.output_path:
            cfg.hf_hub_log_args["output_path"] = cfg.output_path

        if os.environ.get("HF_TOKEN", None):
            cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")

        evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)
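        # The tracker persists aggregated results and per-task samples, and (when
        # configured via --hf_hub_log_args) pushes them to the Hugging Face Hub.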

        # Create task manager (metadata already set up in config validation)
        task_manager = cfg.process_tasks(cfg.metadata)

        # Validation warnings (keep these in CLI as they're logging-specific)
        if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
            eval_logger.warning(
                "Pushing samples to the Hub requires --log_samples to be set."
            )

        # Log task selection (tasks already processed in config)
        if cfg.include_path is not None:
            eval_logger.info(f"Including path: {cfg.include_path}")
        eval_logger.info(f"Selected Tasks: {cfg.tasks}")

        # Run evaluation
        results = simple_evaluate(
            model=cfg.model,
            model_args=cfg.model_args,
            tasks=cfg.tasks,
            num_fewshot=cfg.num_fewshot,
            batch_size=cfg.batch_size,
            max_batch_size=cfg.max_batch_size,
            device=cfg.device,
            use_cache=cfg.use_cache,
            cache_requests=cfg.cache_requests.get("cache_requests", False),
            rewrite_requests_cache=cfg.cache_requests.get(
                "rewrite_requests_cache", False
            ),
            delete_requests_cache=cfg.cache_requests.get(
                "delete_requests_cache", False
            ),
            limit=cfg.limit,
            samples=cfg.samples,
            check_integrity=cfg.check_integrity,
            write_out=cfg.write_out,
            log_samples=cfg.log_samples,
            evaluation_tracker=evaluation_tracker,
            system_instruction=cfg.system_instruction,
            apply_chat_template=cfg.apply_chat_template,
            fewshot_as_multiturn=cfg.fewshot_as_multiturn,
            gen_kwargs=cfg.gen_kwargs,
            task_manager=task_manager,
            verbosity=cfg.verbosity,
            predict_only=cfg.predict_only,
            random_seed=cfg.seed[0] if cfg.seed else None,
            numpy_random_seed=cfg.seed[1] if cfg.seed else None,
            torch_random_seed=cfg.seed[2] if cfg.seed else None,
            fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
            confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
            metadata=cfg.metadata,
        )

        # Process results (simple_evaluate returns None on non-primary ranks in
        # distributed runs, hence the guard)
        if results is not None:
            if cfg.log_samples:
                samples = results.pop("samples")

            dumped = json.dumps(
                results, indent=2, default=handle_non_serializable, ensure_ascii=False
            )
            if cfg.show_config:
                print(dumped)

            batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))
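            # config["batch_sizes"] appears to hold auto-detected batch sizes; it is
            # empty (and omitted from the summary line) when auto batching is unused.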

            # W&B logging
            if cfg.wandb_args:
                try:
                    wandb_logger.post_init(results)
                    wandb_logger.log_eval_result()
                    if cfg.log_samples:
                        wandb_logger.log_eval_samples(samples)
                except Exception as e:
                    eval_logger.warning(f"Logging to W&B failed: {e}")

            # Save results
            evaluation_tracker.save_results_aggregated(
                results=results, samples=samples if cfg.log_samples else None
            )

            if cfg.log_samples:
                for task_name, _ in results["configs"].items():
                    evaluation_tracker.save_results_samples(
                        task_name=task_name, samples=samples[task_name]
                    )

            if (
                evaluation_tracker.push_results_to_hub
                or evaluation_tracker.push_samples_to_hub
            ):
                evaluation_tracker.recreate_metadata_card()

            # Print results; trust_remote_code is dropped, presumably so it is not
            # echoed in the printed model_args
            cfg.model_args.pop("trust_remote_code", None)
            print(
                f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
                f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
                f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
            )
            print(make_table(results))
            if "groups" in results:
                print(make_table(results, "groups"))

            if cfg.wandb_args:
                wandb_logger.run.finish()