__main__.py 13.5 KB
Newer Older
1
2
3
import argparse
import json
import logging
lintangsutawika's avatar
lintangsutawika committed
4
import os
lintangsutawika's avatar
lintangsutawika committed
5
import re
6
import sys
7
from functools import partial
8
from pathlib import Path
haileyschoelkopf's avatar
haileyschoelkopf committed
9
from typing import Union
Leo Gao's avatar
Leo Gao committed
10

11
12
import numpy as np

13
from lm_eval import evaluator, utils
14
from lm_eval.evaluator import request_caching_arg_to_dict
15
from lm_eval.logging_utils import WandbLogger
16
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
17
from lm_eval.utils import make_table, simple_parse_args_string
lintangsutawika's avatar
format  
lintangsutawika committed
18

19

20
21
22
# Name of the aggregated-results file written inside the output directory
# when --output_path does not itself name a .json/.jsonl file.
DEFAULT_RESULTS_FILE: str = "results.json"


def _handle_non_serializable(o):
    """Fallback converter passed as ``default=`` to ``json.dumps``.

    ``json`` only calls this for objects it cannot encode natively.

    - NumPy integer scalars of any width/signedness become Python ``int``.
    - NumPy floating scalars become Python ``float`` (``np.float64`` already
      subclasses ``float`` and never reaches here, but ``np.float32`` and
      ``np.float16`` do and would otherwise be written as strings).
    - Sets become lists.
    - Anything else is stringified with ``str``.
    """
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, set):
        return list(o)
    return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Argparse ``type=`` converter for a list of integers and/or ``None``.

    ``value`` is split on ``split_char``. Each piece is stripped and lowercased;
    ``"none"`` maps to ``None`` and anything else must parse with ``int``.
    A single piece is repeated ``max_len`` times so downstream code can treat
    one value and many values the same way; otherwise exactly ``max_len``
    pieces are required.

    Raises:
        argparse.ArgumentTypeError: if a piece is neither an integer nor
            ``none``, or if the piece count is neither 1 nor ``max_len``.
    """

    def to_int_or_none(raw):
        token = raw.strip().lower()
        if token == "none":
            return None
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{token} is not an integer or None")

    parsed = list(map(to_int_or_none, value.split(split_char)))

    if len(parsed) == 1:
        # Broadcast a single value so callers always receive max_len entries.
        return parsed * max_len
    if len(parsed) != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    return parsed


def parse_eval_args() -> argparse.Namespace:
    """Build the lm-eval command-line parser and parse ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``. Note that ``--seed`` is converted
        to a 3-element list (python ``random``, numpy, torch) by
        ``_int_or_none_list_arg_type``; its string default is converted too.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        # NOTE(review): the default is the int 1 while user-supplied values are
        # str; kept as-is because downstream model code receives this value.
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        # Upper-cased so the value can be looked up as a `logging` attribute.
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )

    return parser.parse_args()


def _resolve_task_names(tasks, task_manager, eval_logger):
    """Turn the ``--tasks`` value into the list handed to ``simple_evaluate``.

    The result holds registered task names (as returned by
    ``task_manager.match_tasks``) and/or task config dicts loaded from YAML.
    Exits the process for ``--tasks list`` (status 0) or when no task was
    given (status 1). Raises ``ValueError`` for unknown task names.
    """
    if tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        # Non-zero status: this is a usage error, not a successful run.
        sys.exit(1)
    if tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()

    if os.path.isdir(tasks):
        # A directory: load every *.yaml file directly inside it as a task config.
        import glob

        yaml_path = os.path.join(tasks, "*.yaml")
        return [utils.load_yaml_config(yaml_file) for yaml_file in glob.glob(yaml_path)]

    task_list = tasks.split(",")
    task_names = task_manager.match_tasks(task_list)
    # Entries that are not registered task names may be paths to YAML configs.
    loaded_yaml_files = set()
    for task in [task for task in task_list if task not in task_names]:
        if os.path.isfile(task):
            task_names.append(utils.load_yaml_config(task))
            loaded_yaml_files.add(task)

    # A YAML path is not itself in task_names (its loaded dict is), so it must
    # be excluded explicitly; we also don't want errors if a wildcard ("*")
    # task name was used.
    task_missing = [
        task
        for task in task_list
        if task not in task_names
        and task not in loaded_yaml_files
        and "*" not in task
    ]
    if task_missing:
        missing = ", ".join(task_missing)
        eval_logger.error(
            f"Tasks were not found: {missing}\n"
            f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
        )
        raise ValueError(
            f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
        )
    return task_names


def _prepare_output_path(output_path, eval_logger):
    """Resolve ``--output_path`` and create the directories it needs.

    Returns ``(samples_dir, results_file)``:
      * a path ending in ``.json``/``.jsonl`` is the results file itself and
        its parent directory receives the per-task sample files;
      * any other path is treated as a directory that receives both the
        samples and ``DEFAULT_RESULTS_FILE``.

    Raises ``FileExistsError`` if ``output_path`` is an existing file.
    """
    path = Path(output_path)
    # check if file or 'dir/results.json' exists
    if path.is_file():
        raise FileExistsError(f"File already exists at {path}")
    output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
    if output_path_file.is_file():
        eval_logger.warning(
            f"File {output_path_file} already exists. Results will be overwritten."
        )
    # if path json then get parent dir
    elif path.suffix in (".json", ".jsonl"):
        output_path_file = path
        path.parent.mkdir(parents=True, exist_ok=True)
        path = path.parent
    else:
        path.mkdir(parents=True, exist_ok=True)
    return path, output_path_file


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an evaluation driven by command-line style arguments.

    Args:
        args: Parsed arguments as produced by ``parse_eval_args``. When
            ``None``, they are parsed from ``sys.argv``.

    Raises:
        ValueError: if ``--log_samples``/``--predict_only`` is given without
            ``--output_path``, or if a requested task cannot be found.
        FileExistsError: if ``--output_path`` is an existing file.
        SystemExit: after ``--tasks list`` (status 0) or when no task was
            given (status 1).
    """
    if args is None:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # --predict_only implies --log_samples; both need somewhere to write samples.
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
        include_path(args.include_path)

    task_names = _resolve_task_names(args.tasks, task_manager, eval_logger)

    if args.output_path:
        path, output_path_file = _prepare_output_path(args.output_path, eval_logger)

    # --trust_remote_code opts in both HF Datasets (via the environment
    # variable) and the model (via model_args).
    if args.trust_remote_code:
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
        remote_code_arg = (
            f"trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
        )
        # Avoid a leading comma when no --model_args were given.
        args.model_args = (
            f"{args.model_args},{remote_code_arg}"
            if args.model_args
            else remote_code_arg
        )

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        **request_caching_args,
    )

    # Nothing to report or write when simple_evaluate returns None.
    if results is None:
        return

    # Samples are written to their own files, so keep them out of the results dump.
    if args.log_samples:
        samples = results.pop("samples")
    dumped = json.dumps(
        results, indent=2, default=_handle_non_serializable, ensure_ascii=False
    )
    if args.show_config:
        print(dumped)

    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

    # Add W&B logging (best effort: a failure here must not abort the run).
    if args.wandb_args:
        try:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            if args.log_samples:
                wandb_logger.log_eval_samples(samples)
        except Exception as e:
            eval_logger.warning(f"Logging to Weights and Biases failed due to {e}")

    if args.output_path:
        # write_text opens and closes the file (the old open().write() leaked it).
        output_path_file.write_text(dumped, encoding="utf-8")

        if args.log_samples:
            for task_name in results["configs"]:
                output_name = "{}_{}".format(
                    re.sub("/|=", "__", args.model_args), task_name
                )
                # NOTE(review): the file is named .jsonl but holds one indented
                # JSON document, not JSON Lines; kept for compatibility.
                filename = path.joinpath(f"{output_name}.jsonl")
                samples_dumped = json.dumps(
                    samples[task_name],
                    indent=2,
                    default=_handle_non_serializable,
                    ensure_ascii=False,
                )
                filename.write_text(samples_dumped, encoding="utf-8")

    print(
        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
        f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
    )
    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

    if args.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()


# Script entry point: only run the CLI when this module is executed directly,
# not when it is imported (e.g. by code that calls cli_evaluate itself).
if __name__ == "__main__":
    cli_evaluate()