import argparse
import json
import logging
import os
import textwrap
from functools import partial

from lm_eval._cli.subcommand import SubCommand
from lm_eval._cli.utils import (
    _int_or_none_list_arg_type,
    key_val_to_dict,
    merge_dicts,
    request_caching_arg_to_dict,
    try_parse_json,
)


class Run(SubCommand):
    """Command for running language model evaluation."""

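    # Registered via subparsers.add_parser() in __init__; `lm-eval run ...`
    # dispatches to Run._execute through set_defaults(func=...).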
    def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._parser = subparsers.add_parser(
            "run",
            help="Run the evaluation harness on specified tasks",
            description="Evaluate language models on various benchmarks and tasks.",
            usage="lm-eval run --model <model> --tasks <task> <task> --model_args <arg=value> <arg=value> [options]",
            epilog=textwrap.dedent("""
                examples:
                  # Basic evaluation with HuggingFace model
                  $ lm-eval run --model hf --model_args pretrained=gpt2 dtype=float32 --tasks hellaswag

                  # Evaluate on multiple tasks with few-shot examples
                  $ lm-eval run --model vllm --model_args pretrained=EleutherAI/gpt-j-6B --tasks arc_easy arc_challenge --num_fewshot 5

                  # Evaluation with custom generation parameters
                  $ lm-eval run --model hf --model_args pretrained=gpt2 --tasks lambada --gen_kwargs temperature=0.8 top_p=0.95 'stop=["\\n\\n"]'

                  # Use configuration file
                  $ lm-eval run --config my_config.yaml --tasks mmlu

                For more information, see: https://github.com/EleutherAI/lm-evaluation-harness
            """),
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )
        self._add_args()
        self._parser.set_defaults(func=self._execute)

    def _add_args(self) -> None:

        # Defaults are set in config/evaluate_config.py
        config_group = self._parser.add_argument_group("configuration")
        config_group.add_argument(
            "--config",
            "-C",
            default=None,
            type=str,
            metavar="YAML_PATH",
            help="Set initial arguments from YAML config",
        )

        # Model and Tasks
        model_group = self._parser.add_argument_group("model and tasks")
        model_group.add_argument(
            "--model",
            "-m",
            type=str,
            default=None,
            metavar="MODEL_NAME",
            help="Model name (default: hf)",
        )
        model_group.add_argument(
            "--tasks",
            "-t",
            default=None,
            type=str,
            nargs="*",
            metavar="TASK1 TASK2",
            help=textwrap.dedent("""
                Space- or comma-separated list of task names or groupings.
                Use 'lm-eval list tasks' to see all available tasks.
            """).strip(),
        )
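        # key_val_to_dict turns each "key=val[,key2=val2]" token into a dict;
        # with nargs="*" the parsed value is a list of dicts that _execute()
        # later flattens with merge_dicts(), e.g. (assuming the helper splits
        # on '=' and ','): ["pretrained=gpt2", "dtype=float32"]
        # -> [{"pretrained": "gpt2"}, {"dtype": "float32"}] -> one merged dict.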
        model_group.add_argument(
            "--model_args",
            "-a",
            default=None,
            nargs="*",
            type=key_val_to_dict,
            metavar="ARGS",
            help="Model arguments as 'key=val,key2=val2' or `key=val` `key2=val2`",
        )

        # Evaluation Settings
        eval_group = self._parser.add_argument_group("evaluation settings")
        eval_group.add_argument(
            "--num_fewshot",
            "-f",
            type=int,
            default=None,
            metavar="N",
            help="Number of examples in few-shot context",
        )
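        # argparse.SUPPRESS leaves the attribute unset unless the flag is given,
        # letting EvaluatorConfig distinguish "not provided" from an explicit
        # value (defaults live in config/evaluate_config.py, per the note above).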
        eval_group.add_argument(
            "--batch_size",
            "-b",
            type=str,
            default=argparse.SUPPRESS,
            metavar="auto|auto:N|N",
            help=textwrap.dedent(
                "Batch size: 'auto', 'auto:N' (auto-tune N times), or integer (default: 1)"
            ),
        )
        eval_group.add_argument(
            "--max_batch_size",
            type=int,
            default=None,
            metavar="N",
            help="Maximum batch size when using --batch_size auto",
        )
        eval_group.add_argument(
            "--device",
            type=str,
            default=None,
            metavar="DEVICE",
            help="Device to use (e.g. cuda, cuda:0, cpu, mps)",
        )
        eval_group.add_argument(
            "--gen_kwargs",
            type=key_val_to_dict,
            default=None,
            nargs="*",
            metavar="KWARGS",
            help=textwrap.dedent(
                'Generation arguments as `temperature=0,stop=["stop"]` or `key=val` `key2=val2`. '
                "Values should be parsable with ast.literal_eval."
            ),
        )

        # Data and Output
        data_group = self._parser.add_argument_group("data and output")
        data_group.add_argument(
            "--output_path",
            "-o",
            default=None,
            type=str,
            metavar="OUTPUT_PATH",
            help="Output dir or json file for results (and samples)",
        )
        data_group.add_argument(
            "--log_samples",
            "-s",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Save all model outputs and documents for post-hoc analysis",
        )
        data_group.add_argument(
            "--limit",
            "-L",
            type=float,
            default=None,
            metavar="N|0.0-1.0",
            help="Limit examples per task (integer count or fraction)",
        )
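        # try_parse_json presumably accepts an inline JSON string (or a path to
        # a JSON file) and returns the parsed object.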
        data_group.add_argument(
            "--samples",
            "-E",
            default=None,
            type=try_parse_json,
            metavar='{"task1": [1,2,3,4,...]}',
            help=textwrap.dedent(
                "`...` `...` Sample indices for inputs. Incompatible with --limit."
                " Values must be parsable with ast.literal_eval."
            ),
        )

        # Caching and Performance
        cache_group = self._parser.add_argument_group("caching and performance")
        cache_group.add_argument(
            "--use_cache",
            "-c",
            type=str,
            default=None,
            metavar="CACHE_DIR",
            help="SQLite database path for caching model outputs.",
        )
        cache_group.add_argument(
            "--cache_requests",
            type=request_caching_arg_to_dict,
            default=None,
            choices=["true", "refresh", "delete"],
            help="Cache dataset request building (true|refresh|delete)",
        )
        cache_group.add_argument(
            "--check_integrity",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Run task test suite validation",
        )

        # Prompt Formatting
        template_group = self._parser.add_argument_group("instruct formatting")
        template_group.add_argument(
            "--system_instruction",
            type=str,
            default=None,
            metavar="INSTRUCTION",
            help="Add custom system instruction.",
        )
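        # nargs="?" with const=True yields three states: flag absent
        # (suppressed), bare --apply_chat_template (True), or
        # --apply_chat_template NAME to pick a specific template.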
        template_group.add_argument(
            "--apply_chat_template",
            type=str,
            nargs="?",
            const=True,
            default=argparse.SUPPRESS,
            metavar="TEMPLATE",
            help="Apply chat template to prompts (optional template name)",
        )
        template_group.add_argument(
            "--fewshot_as_multiturn",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Use fewshot examples as multi-turn conversation",
        )

        # Task Management
        task_group = self._parser.add_argument_group("task management")
        task_group.add_argument(
            "--include_path",
            type=str,
            default=None,
            metavar="TASK_DIR",
            help="Additional directory for external tasks",
        )

        # Logging and Tracking
        logging_group = self._parser.add_argument_group("logging and tracking")
        logging_group.add_argument(
            "--verbosity",
            "-v",
            type=str.upper,
            default=None,
            metavar="LEVEL",
            help="(Deprecated) Log level. Use LOGLEVEL env var instead",
        )
        logging_group.add_argument(
            "--write_out",
            "-w",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Print prompts for first few documents",
        )
        logging_group.add_argument(
            "--show_config",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Display full task configuration after evaluation",
        )
        logging_group.add_argument(
            "--wandb_args",
            type=key_val_to_dict,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Weights & Biases init arguments as `key=val` `key2=val2`",
        )
        logging_group.add_argument(
            "--wandb_config_args",
            type=key_val_to_dict,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Weights & Biases config arguments as `key=val` `key2=val2`",
        )
        logging_group.add_argument(
            "--hf_hub_log_args",
            type=key_val_to_dict,
            default=argparse.SUPPRESS,
            metavar="ARGS",
            help="Hugging Face Hub logging arguments as `key=val` `key2=val2`",
        )

        # Advanced Options
        advanced_group = self._parser.add_argument_group("advanced options")
        advanced_group.add_argument(
            "--predict_only",
            "-x",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Save predictions only, skip metric computation",
        )
        default_seed_string = "0,1234,1234,1234"
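        # _int_or_none_list_arg_type presumably parses a comma-separated list of
        # ints/None, accepting between 3 and 4 entries (the bound arguments) and
        # padding omissions from default_seed_string.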
        advanced_group.add_argument(
            "--seed",
            type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
            default=None,
            metavar="SEED|S1,S2,S3,S4",
            help=textwrap.dedent(f"""
                Random seeds for python,numpy,torch,fewshot (default: {default_seed_string}).
                Use single integer for all, or comma-separated list of 4 values.
                Use 'None' to skip setting a seed. Example: --seed 42 or --seed 0,None,8,52
            """).strip(),
        )
        advanced_group.add_argument(
            "--trust_remote_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Allow executing remote code from Hugging Face Hub",
        )
        advanced_group.add_argument(
            "--confirm_run_unsafe_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Confirm understanding of unsafe code execution risks",
        )
        advanced_group.add_argument(
            "--metadata",
            type=json.loads,
            default=None,
            metavar="JSON_STRING",
            help=textwrap.dedent(
                """JSON string parsed with json.loads and merged with model_args;
                required for some tasks such as RULER"""
            ),
        )

    @staticmethod
    def _execute(args: argparse.Namespace) -> None:
        """Runs the evaluation harness with the provided arguments."""
        os.environ["TOKENIZERS_PARALLELISM"] = "false"
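        # Each of these flags is parsed with nargs="*" and type=key_val_to_dict,
        # so when present its value is a list of dicts (one per CLI token);
        # collapse each list into a single dict before building the config.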
        MERGE_ARGS_DICTS = [
            "model_args",
            "gen_kwargs",
            "wandb_args",
            "wandb_config_args",
            "hf_hub_log_args",
        ]
        for arg_name in MERGE_ARGS_DICTS:
            if current_value := getattr(args, arg_name, None):
                setattr(args, arg_name, merge_dicts(*current_value))

        from lm_eval.config.evaluate_config import EvaluatorConfig

        eval_logger = logging.getLogger(__name__)

        # Create and validate config (most validation now occurs in EvaluatorConfig)
        cfg = EvaluatorConfig.from_cli(args)

        from lm_eval import simple_evaluate
        from lm_eval.loggers import EvaluationTracker, WandbLogger
        from lm_eval.utils import handle_non_serializable, make_table

        # Set up logging
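        # NB: wandb_logger is only bound when cfg.wandb_args is truthy; every
        # later use is guarded by the same condition.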
        if cfg.wandb_args:
            wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)

        # Set up evaluation tracker
        if cfg.output_path:
            cfg.hf_hub_log_args["output_path"] = cfg.output_path

        if os.environ.get("HF_TOKEN", None):
            cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")

        evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)

        # Create task manager (metadata already set up in config validation)
        task_manager = cfg.process_tasks(cfg.metadata)

        # Validation warnings (keep these in CLI as they're logging-specific)
        if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
            eval_logger.warning(
                "Pushing samples to the Hub requires --log_samples to be set."
            )

        # Log task selection (tasks already processed in config)
        if cfg.include_path is not None:
            eval_logger.info(f"Including path: {cfg.include_path}")
        eval_logger.info(f"Selected Tasks: {cfg.tasks}")

        # Run evaluation
        results = simple_evaluate(
            model=cfg.model,
            model_args=cfg.model_args,
            tasks=cfg.tasks,
            num_fewshot=cfg.num_fewshot,
            batch_size=cfg.batch_size,
            max_batch_size=cfg.max_batch_size,
            device=cfg.device,
            use_cache=cfg.use_cache,
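            # cfg.cache_requests comes from request_caching_arg_to_dict, which
            # presumably maps the CLI choice (true|refresh|delete) onto these
            # three boolean flags.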
            cache_requests=cfg.cache_requests.get("cache_requests", False),
            rewrite_requests_cache=cfg.cache_requests.get(
                "rewrite_requests_cache", False
            ),
            delete_requests_cache=cfg.cache_requests.get(
                "delete_requests_cache", False
            ),
            limit=cfg.limit,
            samples=cfg.samples,
            check_integrity=cfg.check_integrity,
            write_out=cfg.write_out,
            log_samples=cfg.log_samples,
            evaluation_tracker=evaluation_tracker,
            system_instruction=cfg.system_instruction,
            apply_chat_template=cfg.apply_chat_template,
            fewshot_as_multiturn=cfg.fewshot_as_multiturn,
            gen_kwargs=cfg.gen_kwargs,
            task_manager=task_manager,
            verbosity=cfg.verbosity,
            predict_only=cfg.predict_only,
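            # cfg.seed is a 4-element list ordered python, numpy, torch, fewshot
            # (see the --seed help above).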
            random_seed=cfg.seed[0] if cfg.seed else None,
            numpy_random_seed=cfg.seed[1] if cfg.seed else None,
            torch_random_seed=cfg.seed[2] if cfg.seed else None,
            fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
            confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
            metadata=cfg.metadata,
        )

        # Process results
        if results is not None:
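            # Pop the per-document samples first so they are excluded from the
            # aggregated JSON dump; they are saved per task further down.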
            if cfg.log_samples:
                samples = results.pop("samples")

            dumped = json.dumps(
                results, indent=2, default=handle_non_serializable, ensure_ascii=False
            )
            if cfg.show_config:
                print(dumped)

            batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

            # W&B logging
            if cfg.wandb_args:
                try:
                    wandb_logger.post_init(results)
                    wandb_logger.log_eval_result()
                    if cfg.log_samples:
                        wandb_logger.log_eval_samples(samples)
                except Exception as e:
                    eval_logger.info(f"Logging to W&B failed: {e}")

            # Save results
            evaluation_tracker.save_results_aggregated(
                results=results, samples=samples if cfg.log_samples else None
            )

            if cfg.log_samples:
                for task_name in results["configs"]:
                    evaluation_tracker.save_results_samples(
                        task_name=task_name, samples=samples[task_name]
                    )

            if (
                evaluation_tracker.push_results_to_hub
                or evaluation_tracker.push_samples_to_hub
            ):
                evaluation_tracker.recreate_metadata_card()

            # Print results
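            # Drop trust_remote_code from the echoed model_args; presumably it
            # was already consumed during model construction and is noise here.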
            cfg.model_args.pop("trust_remote_code", None)
            print(
                f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
                f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
                f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
            )
            print(make_table(results))
            if "groups" in results:
                print(make_table(results, "groups"))

            if cfg.wandb_args:
                wandb_logger.run.finish()