run.py 15.9 KB
Newer Older
Baber's avatar
Baber committed
1
2
3
4
5
6
import argparse
import json
import logging
import os
from functools import partial

Baber's avatar
Baber committed
7
8
9
10
11
12
from lm_eval._cli import SubCommand
from lm_eval._cli.utils import (
    _int_or_none_list_arg_type,
    request_caching_arg_to_dict,
    try_parse_json,
)
Baber's avatar
Baber committed
13
14


Baber's avatar
Baber committed
15
class Run(SubCommand):
    """Command for running language model evaluation."""

    def __init__(self, subparsers: argparse._SubParsersAction, *args, **kwargs):
        """Register the `run` subcommand on *subparsers* and wire its arguments."""
        super().__init__(*args, **kwargs)

        # RawDescriptionHelpFormatter preserves the epilog's example block verbatim.
        run_parser = subparsers.add_parser(
            "run",
            help="Run language model evaluation",
            description="Evaluate language models on various benchmarks and tasks.",
            epilog="""
Examples:
  lm-eval run --model hf --model_args pretrained=gpt2 --tasks hellaswag
  lm-eval run --config my_config.yaml --tasks arc_easy,arc_challenge
  lm-eval run --model openai --tasks mmlu --num_fewshot 5
            """,
            formatter_class=argparse.RawDescriptionHelpFormatter,
        )

        # Attach all `run`-specific options.
        self._add_args(run_parser)

        # `lm-eval run ...` dispatches to self.execute with the parsed namespace.
        run_parser.set_defaults(func=self.execute)

    def _add_args(self, parser: argparse.ArgumentParser) -> None:
        """Attach all `lm-eval run` CLI options to *parser*.

        Options whose default is ``argparse.SUPPRESS`` are omitted from the
        parsed namespace entirely when not passed on the command line, so a
        config file (``--config``) can supply them without being clobbered
        by CLI defaults.
        """
        parser.add_argument(
            "--config",
            "-C",
            default=None,
            type=str,
            metavar="DIR/file.yaml",
            help="Path to config with all arguments for `lm-eval`",
        )
        parser.add_argument(
            "--model",
            "-m",
            type=str,
            default="hf",
            help="Name of model. Default 'hf'",
        )
        parser.add_argument(
            "--tasks",
            "-t",
            default=None,
            type=str,
            metavar="task1,task2",
            help="Comma-separated list of task names or task groupings to evaluate on.\nTo get full list of tasks, use one of the commands `lm-eval --tasks {{list_groups,list_subtasks,list_tags,list}}` to list out all available names for task groupings; only (sub)tasks; tags; or all of the above",
        )
        parser.add_argument(
            "--model_args",
            "-a",
            default=None,
            type=try_parse_json,
            help="""Comma separated string or JSON formatted arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32` or '{"pretrained":"EleutherAI/pythia-160m","dtype":"float32"}'.""",
        )
        parser.add_argument(
            "--num_fewshot",
            "-f",
            type=int,
            default=None,
            metavar="N",
            help="Number of examples in few-shot context",
        )
        parser.add_argument(
            "--batch_size",
            "-b",
            type=str,
            default=argparse.SUPPRESS,
            metavar="auto|auto:N|N",
            help="Acceptable values are 'auto', 'auto:N' (recompute batchsize N times with time) or N, where N is an integer. Default 1.",
        )
        parser.add_argument(
            "--max_batch_size",
            type=int,
            default=None,
            metavar="N",
            help="Maximal batch size to try with --batch_size auto.",
        )
        parser.add_argument(
            "--device",
            type=str,
            default=None,
            help="Device to use (e.g. cuda, cuda:0, cpu). Model defaults. Default None.",
        )
        parser.add_argument(
            "--output_path",
            "-o",
            default=None,
            type=str,
            metavar="DIR|DIR/file.json",
            help="Path where result metrics will be saved. Can be either a directory or a .json file. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
        )
        parser.add_argument(
            "--limit",
            "-L",
            type=float,
            default=None,
            metavar="N|0<N<1",
            help="Limit the number of examples per task. "
            "If <1, limit is a percentage of the total number of examples.",
        )
        parser.add_argument(
            "--samples",
            "-E",
            default=None,
            type=try_parse_json,
            metavar="/path/to/json",
            help='JSON string or path to JSON file containing doc indices of selected examples to test. Format: {"task_name":[indices],...}',
        )
        parser.add_argument(
            "--use_cache",
            "-c",
            type=str,
            default=None,
            metavar="DIR",
            help="A path to a sqlite db file for caching model responses. `None` if not caching.",
        )
        # FIX: argparse applies `type` conversion *before* validating `choices`.
        # Pairing type=request_caching_arg_to_dict (which returns a dict) with
        # string choices meant every converted value failed the choices check,
        # so any explicit `--cache_requests <value>` errored out. The choices
        # list is dropped; accepted values are documented via metavar/help and
        # validation is left to the converter.
        parser.add_argument(
            "--cache_requests",
            type=request_caching_arg_to_dict,
            default=None,
            metavar="true|refresh|delete",
            help="Speed up evaluation by caching the building of dataset requests. `None` if not caching. One of: true, refresh, delete.",
        )
        parser.add_argument(
            "--check_integrity",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Whether to run the relevant part of the test suite for the tasks.",
        )
        parser.add_argument(
            "--write_out",
            "-w",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Prints the prompt for the first few documents.",
        )
        parser.add_argument(
            "--log_samples",
            "-s",
            action="store_true",
            default=argparse.SUPPRESS,
            help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
        )
        parser.add_argument(
            "--system_instruction",
            type=str,
            default=None,
            help="System instruction to be used in the prompt",
        )
        # nargs="?" + const=True: a bare `--apply_chat_template` means True,
        # while `--apply_chat_template NAME` selects a specific template.
        parser.add_argument(
            "--apply_chat_template",
            type=str,
            nargs="?",
            const=True,
            default=argparse.SUPPRESS,
            help=(
                "If True, apply chat template to the prompt. "
                "Providing `--apply_chat_template` without an argument will apply the default chat template to the prompt. "
                "To apply a specific template from the available list of templates, provide the template name as an argument. "
                "E.g. `--apply_chat_template template_name`"
            ),
        )
        parser.add_argument(
            "--fewshot_as_multiturn",
            action="store_true",
            default=argparse.SUPPRESS,
            help="If True, uses the fewshot as a multi-turn conversation",
        )
        parser.add_argument(
            "--show_config",
            action="store_true",
            default=argparse.SUPPRESS,
            # typo fix: was "shows the the full config"
            help="If True, shows the full config of all tasks at the end of the evaluation.",
        )
        parser.add_argument(
            "--include_path",
            type=str,
            default=None,
            metavar="DIR",
            help="Additional path to include if there are external tasks to include.",
        )
        parser.add_argument(
            "--gen_kwargs",
            type=try_parse_json,
            default=None,
            help=(
                "Either comma delimited string or JSON formatted arguments for model generation on greedy_until tasks,"
                """ e.g. '{"do_sample": True, temperature":0.7,"until":["hello"]}' or temperature=0,top_p=0.1."""
            ),
        )
        parser.add_argument(
            "--verbosity",
            "-v",
            type=str.upper,
            default=None,
            metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
            help="(Deprecated) Controls logging verbosity level. Use the `LOGLEVEL` environment variable instead. Set to DEBUG for detailed output when testing or adding new task configurations.",
        )
        parser.add_argument(
            "--wandb_args",
            type=str,
            default=argparse.SUPPRESS,
            help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
        )
        parser.add_argument(
            "--wandb_config_args",
            type=str,
            default=argparse.SUPPRESS,
            help="Comma separated string arguments passed to wandb.config.update. Use this to trace parameters that aren't already traced by default. eg. `lr=0.01,repeats=3`",
        )
        parser.add_argument(
            "--hf_hub_log_args",
            type=str,
            default=argparse.SUPPRESS,
            help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
        )
        parser.add_argument(
            "--predict_only",
            "-x",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
        )
        # --seed accepts 1 value (applied to all four seeds) or up to 4
        # comma-separated values: python random, numpy, torch, fewshot sampling.
        default_seed_string = "0,1234,1234,1234"
        parser.add_argument(
            "--seed",
            type=partial(_int_or_none_list_arg_type, 3, 4, default_seed_string),
            default=default_seed_string,  # for backward compatibility
            help=(
                "Set seed for python's random, numpy, torch, and fewshot sampling.\n"
                "Accepts a comma-separated list of 4 values for python's random, numpy, torch, and fewshot sampling seeds, "
                "respectively, or a single integer to set the same seed for all four.\n"
                f"The values are either an integer or 'None' to not set the seed. Default is `{default_seed_string}` "
                "(for backward compatibility).\n"
                "E.g. `--seed 0,None,8,52` sets `random.seed(0)`, `torch.manual_seed(8)`, and fewshot sampling seed to 52. "
                "Here numpy's seed is not set since the second value is `None`.\n"
                "E.g, `--seed 42` sets all four seeds to 42."
            ),
        )
        parser.add_argument(
            "--trust_remote_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
        )
        parser.add_argument(
            "--confirm_run_unsafe_code",
            action="store_true",
            default=argparse.SUPPRESS,
            help="Confirm that you understand the risks of running unsafe code for tasks that require it",
        )
        parser.add_argument(
            "--metadata",
            type=json.loads,
            default=None,
            help="""JSON string metadata to pass to task configs, for example '{"max_seq_lengths":[4096,8192]}'. Will be merged with model_args. Can also be set in task config.""",
        )

    def execute(self, args: argparse.Namespace) -> None:
        """Execute the evaluation command.

        Builds an EvaluatorConfig from the parsed CLI namespace, sets up
        logging and result tracking (W&B, HF Hub, local output path), runs
        `lm_eval.simple_evaluate`, then saves and prints the results.
        """
        # Deferred import: avoids pulling in heavy evaluation deps at CLI startup.
        from lm_eval.config.evaluate_config import EvaluatorConfig

        # Create and validate config (most validation now happens in EvaluationConfig)
        cfg = EvaluatorConfig.from_cli(args)

        from lm_eval import simple_evaluate, utils
        from lm_eval.loggers import EvaluationTracker, WandbLogger
        from lm_eval.utils import handle_non_serializable, make_table

        # Set up logging
        # wandb_logger is only bound when --wandb_args was given; all later
        # uses are guarded by the same `cfg.wandb_args` check.
        if cfg.wandb_args:
            wandb_logger = WandbLogger(cfg.wandb_args, cfg.wandb_config_args)

        utils.setup_logging(cfg.verbosity)
        eval_logger = logging.getLogger(__name__)
        # Avoid tokenizers' fork-related parallelism warnings/deadlocks.
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        # Set up evaluation tracker
        # hf_hub_log_args is mutated in place with output path and HF token
        # before being expanded into the tracker.
        if cfg.output_path:
            cfg.hf_hub_log_args["output_path"] = cfg.output_path

        if os.environ.get("HF_TOKEN", None):
            cfg.hf_hub_log_args["token"] = os.environ.get("HF_TOKEN")

        evaluation_tracker = EvaluationTracker(**cfg.hf_hub_log_args)

        # Create task manager (metadata already set up in config validation)
        task_manager = cfg.process_tasks()

        # Validation warnings (keep these in CLI as they're logging-specific)
        if "push_samples_to_hub" in cfg.hf_hub_log_args and not cfg.log_samples:
            eval_logger.warning(
                "Pushing samples to the Hub requires --log_samples to be set."
            )

        # Log task selection (tasks already processed in config)
        if cfg.include_path is not None:
            eval_logger.info(f"Including path: {cfg.include_path}")
        eval_logger.info(f"Selected Tasks: {cfg.tasks}")

        # Run evaluation
        results = simple_evaluate(
            model=cfg.model,
            model_args=cfg.model_args,
            tasks=cfg.tasks,
            num_fewshot=cfg.num_fewshot,
            batch_size=cfg.batch_size,
            max_batch_size=cfg.max_batch_size,
            device=cfg.device,
            use_cache=cfg.use_cache,
            # NOTE(review): assumes cfg.cache_requests is a dict here; the
            # argparse default for --cache_requests is None, so these .get
            # calls would raise unless EvaluatorConfig normalizes it — confirm.
            cache_requests=cfg.cache_requests.get("cache_requests", False),
            rewrite_requests_cache=cfg.cache_requests.get(
                "rewrite_requests_cache", False
            ),
            delete_requests_cache=cfg.cache_requests.get(
                "delete_requests_cache", False
            ),
            limit=cfg.limit,
            samples=cfg.samples,
            check_integrity=cfg.check_integrity,
            write_out=cfg.write_out,
            log_samples=cfg.log_samples,
            evaluation_tracker=evaluation_tracker,
            system_instruction=cfg.system_instruction,
            apply_chat_template=cfg.apply_chat_template,
            fewshot_as_multiturn=cfg.fewshot_as_multiturn,
            gen_kwargs=cfg.gen_kwargs,
            task_manager=task_manager,
            verbosity=cfg.verbosity,
            predict_only=cfg.predict_only,
            # cfg.seed order: python random, numpy, torch, fewshot sampling.
            random_seed=cfg.seed[0] if cfg.seed else None,
            numpy_random_seed=cfg.seed[1] if cfg.seed else None,
            torch_random_seed=cfg.seed[2] if cfg.seed else None,
            fewshot_random_seed=cfg.seed[3] if cfg.seed else None,
            confirm_run_unsafe_code=cfg.confirm_run_unsafe_code,
            metadata=cfg.metadata,
        )

        # Process results
        if results is not None:
            # `samples` is only bound when log_samples is on; every later use
            # is guarded by the same cfg.log_samples check.
            if cfg.log_samples:
                samples = results.pop("samples")

            dumped = json.dumps(
                results, indent=2, default=handle_non_serializable, ensure_ascii=False
            )
            if cfg.show_config:
                print(dumped)

            batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

            # W&B logging
            # Best-effort: W&B failures are logged, never fatal to the run.
            if cfg.wandb_args:
                try:
                    wandb_logger.post_init(results)
                    wandb_logger.log_eval_result()
                    if cfg.log_samples:
                        wandb_logger.log_eval_samples(samples)
                except Exception as e:
                    eval_logger.info(f"Logging to W&B failed: {e}")

            # Save results
            evaluation_tracker.save_results_aggregated(
                results=results, samples=samples if cfg.log_samples else None
            )

            if cfg.log_samples:
                for task_name, _ in results["configs"].items():
                    evaluation_tracker.save_results_samples(
                        task_name=task_name, samples=samples[task_name]
                    )

            if (
                evaluation_tracker.push_results_to_hub
                or evaluation_tracker.push_samples_to_hub
            ):
                evaluation_tracker.recreate_metadata_card()

            # Print results
            print(
                f"{cfg.model} ({cfg.model_args}), gen_kwargs: ({cfg.gen_kwargs}), "
                f"limit: {cfg.limit}, num_fewshot: {cfg.num_fewshot}, "
                f"batch_size: {cfg.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
            )
            print(make_table(results))
            if "groups" in results:
                print(make_table(results, "groups"))

            if cfg.wandb_args:
                wandb_logger.run.finish()