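"""Command-line entry point for lm-eval: builds the CLI argument parser and runs evaluation."""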
import argparse
import json
import logging
import os
import sys
from argparse import Namespace
from functools import partial
from typing import Union

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging import EvaluationTracker, WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
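    """Parse a comma-separated list of ints/'none'; a single value is repeated max_len times."""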
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items


def check_argument_types(parser: argparse.ArgumentParser):
    """
    Check that all CLI arguments have a type specified; raise a ValueError if not.
    """
    for action in parser._actions:
        if action.dest != "help" and not action.const:
            if action.type is None:
                raise ValueError(
                    f"Argument '{action.dest}' doesn't have a type specified."
                )


def setup_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        type=str,
        metavar="task1,task2",
        help="To get the full list of tasks, run `lm-eval --tasks list`",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        type=str,
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Otherwise, the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        type=str,
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--hf_hub_log_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
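    # `partial` binds max_len=3: one seed each for python's random, numpy, and torch.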
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True, allowing execution of remote code when loading HF Datasets from the Hub",
    )
    return parser


def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    check_argument_types(parser)
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

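    # Set up the W&B logger first so the run is available when results are logged later.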
    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # update the evaluation tracker args with the output path and the HF token
    args.hf_hub_log_args = f"output_path={args.output_path},token={os.environ.get('HF_TOKEN')},{args.hf_hub_log_args}"
    evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
    evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
    evaluation_tracker.general_config_tracker.log_experiment_args(
        model_source=args.model,
        model_args=args.model_args,
    )

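    # --predict_only implies --log_samples, and both require an explicit --output_path.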
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

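    # Build the task index, including any external task definitions from --include_path.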
    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    # These keys may be absent from the parsed args, so read them with defaults.
    evaluation_tracker_args = Namespace(**evaluation_tracker_args)
    if (
        getattr(evaluation_tracker_args, "push_results_to_hub", False)
        or getattr(evaluation_tracker_args, "push_samples_to_hub", False)
    ) and not getattr(evaluation_tracker_args, "hub_results_org", None):
        raise ValueError(
            "If push_results_to_hub or push_samples_to_hub is set, hub_results_org must be specified."
        )
    if (
        getattr(evaluation_tracker_args, "push_samples_to_hub", False)
        and not args.log_samples
    ):
        eval_logger.warning(
            "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
        )

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

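    # Resolve --tasks: list the registered tasks, load YAML configs from a directory
    # or file path, or match the given names against the task registry.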
    if args.tasks is None:
        eval_logger.error("Need to specify a task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()
    else:
        if os.path.isdir(args.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    # If --trust_remote_code was passed, propagate it to HF datasets and to the model args
    if args.trust_remote_code:
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
        args.model_args = (
            args.model_args
            + f",trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
        )

    eval_logger.info(f"Selected Tasks: {task_names}")

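    # Translate the --cache_requests string into keyword arguments for simple_evaluate.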
    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        verbosity=args.verbosity,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        **request_caching_args,
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

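        # Batch sizes recorded in the run config; shown in the summary line below.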
        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        # `samples` is only defined when --log_samples is set
        evaluation_tracker.save_results_aggregated(
            results=results, samples=samples if args.log_samples else None
        )

        if args.log_samples:
            for task_name, config in results["configs"].items():
                evaluation_tracker.save_results_samples(
                    task_name=task_name, samples=samples[task_name]
                )

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()


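# Example invocation (illustrative; the model and task names depend on your setup):
#   python -m lm_eval --model hf \
#       --model_args pretrained=EleutherAI/pythia-160m,dtype=float32 \
#       --tasks lambada_openai --batch_size auto --output_path results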
if __name__ == "__main__":
    cli_evaluate()