"""Command-line entry point for lm-eval (``lm-eval`` / ``python -m lm_eval``)."""

import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table, simple_parse_args_string


# Name of the aggregated-metrics file written when --output_path is a directory.
DEFAULT_RESULTS_FILE = "results.json"


def _handle_non_serializable(o):
    """Fallback ``default=`` hook for :func:`json.dumps`.

    NumPy integer scalars become Python ``int`` and sets/frozensets become
    lists; any other object is converted with ``str`` so dumping never fails.
    """
    if isinstance(o, np.integer):
        # np.integer covers np.int64/np.int32 (as before) plus every other
        # signed/unsigned width, which previously fell through to str().
        return int(o)
    elif isinstance(o, (set, frozenset)):
        return list(o)
    else:
        return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse ``value`` into a list of ``max_len`` ints-or-None (argparse ``type``).

    Each ``split_char``-separated token is an integer or the word ``none``
    (case-insensitive, surrounding whitespace ignored). A single token is
    repeated ``max_len`` times; otherwise exactly ``max_len`` tokens are
    required.

    Raises:
        argparse.ArgumentTypeError: on a token that is neither an int nor
            ``none``, or on a token count other than 1 or ``max_len``.
    """

    def _to_int_or_none(token):
        token = token.strip().lower()
        if token == "none":
            return None
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{token} is not an integer or None")

    # Every token is parsed before the count is checked, so a bad token is
    # reported ahead of a wrong count.
    parsed = list(map(_to_int_or_none, value.split(split_char)))

    if len(parsed) == 1:
        # Broadcast a lone value so callers can always index 0..max_len-1.
        return parsed * max_len
    if len(parsed) == max_len:
        return parsed
    raise argparse.ArgumentTypeError(
        f"Argument requires {max_len} integers or None, separated by '{split_char}'"
    )


def parse_eval_args(argv: Union[list, None] = None) -> argparse.Namespace:
    """Build the ``lm-eval`` command-line parser and parse arguments.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing an explicit list
            allows programmatic use and testing.

    Returns:
        The parsed :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        # Non-string default: argparse does not pass it through `type`, so
        # the unset value stays the int 1.
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a fraction of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        # Upper-cased so it can be looked up as a `logging` level name.
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        # String default: argparse runs it through `type`, yielding [0, 1234, 1234].
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )

    return parser.parse_args(argv)


def _resolve_task_names(tasks: str, task_manager: "TaskManager") -> list:
    """Turn the ``--tasks`` value into the task list for ``simple_evaluate``.

    ``tasks`` is either a directory, in which case every ``*.yaml`` file in it
    is loaded as a task config, or a comma-separated list whose entries are
    registered task names, wildcard patterns, or paths to YAML task configs.

    Raises:
        ValueError: if a non-wildcard entry is neither a known task nor a file.
    """
    eval_logger = utils.eval_logger

    if os.path.isdir(tasks):
        import glob

        yaml_path = os.path.join(tasks, "*.yaml")
        return [utils.load_yaml_config(yaml_file) for yaml_file in glob.glob(yaml_path)]

    task_list = tasks.split(",")
    task_names = task_manager.match_tasks(task_list)
    task_missing = []
    for task in [task for task in task_list if task not in task_names]:
        if os.path.isfile(task):
            # Not a registered name but an existing path: load it as a YAML
            # task config. It is found, so it must not be reported missing.
            task_names.append(utils.load_yaml_config(task))
        elif "*" not in task:
            # we don't want errors if a wildcard ("*") task name was used
            task_missing.append(task)

    if task_missing:
        missing = ", ".join(task_missing)
        eval_logger.error(
            f"Tasks were not found: {missing}\n"
            f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
        )
        raise ValueError(
            f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
        )
    return task_names


def _prepare_output_path(output_path: str) -> tuple:
    """Resolve ``--output_path`` into ``(samples_dir, results_file)``, creating dirs.

    * An existing file is refused.
    * If ``<output_path>/DEFAULT_RESULTS_FILE`` already exists, it is reused
      (with a warning that it will be overwritten).
    * A path ending in ``.json``/``.jsonl`` is the results file itself; its
      parent directory receives the per-task sample files.
    * Anything else is created as a directory holding ``DEFAULT_RESULTS_FILE``
      and the sample files.

    Raises:
        FileExistsError: if ``output_path`` is an existing file.
    """
    path = Path(output_path)
    # check if file or 'dir/results.json' exists
    if path.is_file():
        raise FileExistsError(f"File already exists at {path}")
    output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
    if output_path_file.is_file():
        utils.eval_logger.warning(
            f"File {output_path_file} already exists. Results will be overwritten."
        )
    # if path json then get parent dir
    elif path.suffix in (".json", ".jsonl"):
        output_path_file = path
        path.parent.mkdir(parents=True, exist_ok=True)
        path = path.parent
    else:
        path.mkdir(parents=True, exist_ok=True)
    return path, output_path_file


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an lm-eval evaluation from parsed command-line arguments.

    Args:
        args: Pre-parsed arguments. When ``None``, they are parsed from the
            command line with :func:`parse_eval_args`.

    Raises:
        ValueError: if ``--log_samples``/``--predict_only`` is given without
            ``--output_path``, or if a requested task cannot be found.
        FileExistsError: if ``--output_path`` names an existing file.
        SystemExit: with status 1 when no ``--tasks`` is given, and with
            status 0 after printing the task list for ``--tasks list``.
    """
    if args is None:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # --predict_only implies --log_samples: the samples are its only output.
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit(1)
    if args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        # Listing is a terminal action; there is nothing to evaluate.
        sys.exit()
    task_names = _resolve_task_names(args.tasks, task_manager)

    path = output_path_file = None
    if args.output_path:
        path, output_path_file = _prepare_output_path(args.output_path)

    # Only when --trust_remote_code is passed: let HF Datasets execute remote
    # code and forward the same flag to the model through model_args.
    if args.trust_remote_code:
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
        trust_arg = f"trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
        # Avoid a leading comma when no other model args were given.
        args.model_args = (
            f"{args.model_args},{trust_arg}" if args.model_args else trust_arg
        )

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        verbosity=args.verbosity,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        **request_caching_args,
    )

    # Nothing to report or save when simple_evaluate returns None.
    if results is None:
        return

    # Samples are written to their own files, not into the results JSON.
    if args.log_samples:
        samples = results.pop("samples")
    dumped = json.dumps(
        results, indent=2, default=_handle_non_serializable, ensure_ascii=False
    )
    if args.show_config:
        print(dumped)

    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

    # Add W&B logging
    if args.wandb_args:
        try:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            if args.log_samples:
                wandb_logger.log_eval_samples(samples)
        except Exception as e:
            # Best-effort: a W&B failure must not discard the evaluation output.
            eval_logger.warning(f"Logging to Weights and Biases failed due to {e}")

    if args.output_path:
        # write_text opens and closes the file (the old open().write leaked it).
        output_path_file.write_text(dumped, encoding="utf-8")

        if args.log_samples:
            # Filesystem-safe model tag: '/' and '=' become '__'.
            model_tag = re.sub("/|=", "__", args.model_args)
            for task_name in results["configs"]:
                filename = path.joinpath(f"{model_tag}_{task_name}.jsonl")
                samples_dumped = json.dumps(
                    samples[task_name],
                    indent=2,
                    default=_handle_non_serializable,
                    ensure_ascii=False,
                )
                filename.write_text(samples_dumped, encoding="utf-8")

    print(
        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
        f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
    )
    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

    if args.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()


# Allow `python -m lm_eval` / direct execution to run the CLI.
if __name__ == "__main__":
    cli_evaluate()