import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager
from lm_eval.utils import make_table, simple_parse_args_string

# File name used for the aggregated results when --output_path names a directory.
DEFAULT_RESULTS_FILE = "results.json"


def _handle_non_serializable(o):
    """Fallback encoder passed as ``default=`` to ``json.dumps``.

    Called only for objects the stdlib JSON encoder cannot serialize:

    - NumPy integer scalars of any width/signedness become Python ``int``
      (previously only ``int64``/``int32`` were converted; other widths were
      silently written as strings);
    - ``set`` / ``frozenset`` become ``list`` (element order follows iteration);
    - anything else falls back to ``str(o)``.
    """
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, (set, frozenset)):
        return list(o)
    return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse a ``split_char``-separated list of integers and/or ``None`` (argparse type).

    Each token is stripped and matched case-insensitively against ``"none"``;
    every other token must parse with ``int``.

    Args:
        max_len: Number of values the option expects.
        value: Raw command-line string, e.g. ``"0,None,8"`` or ``"42"``.
        split_char: Separator between tokens.

    Returns:
        A list of exactly ``max_len`` ints/``None``. A single token is repeated
        ``max_len`` times so callers handle one or many values uniformly.

    Raises:
        argparse.ArgumentTypeError: a token is neither an int nor ``none``, or
            the token count is neither 1 nor ``max_len``.
    """

    def _convert(token):
        token = token.strip().lower()
        if token == "none":
            return None
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{token} is not an integer or None")

    parsed = list(map(_convert, value.split(split_char)))

    if len(parsed) == 1:
        # Broadcast a lone value so single- and multi-value input look alike downstream.
        return parsed * max_len
    if len(parsed) != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    return parsed


def check_argument_types(parser: argparse.ArgumentParser):
    """
    Ensure every CLI argument on ``parser`` declares an explicit ``type``.

    The ``help`` action and any action with a truthy ``const`` (for example
    ``store_true`` flags) are exempt.

    Raises:
        ValueError: naming the first non-exempt argument whose ``type`` is None.
    """
    for action in parser._actions:
        exempt = action.dest == "help" or action.const
        if not exempt and action.type is None:
            raise ValueError(
                f"Argument '{action.dest}' doesn't have a type specified."
            )


def setup_parser() -> argparse.ArgumentParser:
    """Build the ``lm-eval`` command-line argument parser.

    Every option other than ``help`` and the ``store_true`` flags declares a
    ``type``, which is what ``check_argument_types`` enforces.

    Returns:
        The configured, not-yet-parsed ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        type=str,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        type=str,
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        # NOTE: argparse applies `type` only to *string* defaults. So the default
        # stays the int 1, while user-supplied values arrive as str.
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        type=str,
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        # Always yields a 3-element list (random, numpy, torch); a single value is broadcast.
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )

    return parser


def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Validate that ``parser`` is fully typed, then parse ``sys.argv``.

    Raises:
        ValueError: from ``check_argument_types`` if an argument lacks a ``type``.
    """
    check_argument_types(parser)
    return parser.parse_args()


def _resolve_task_names(args: argparse.Namespace, task_manager, eval_logger) -> list:
    """Turn ``args.tasks`` into the task list passed to ``simple_evaluate``.

    ``args.tasks`` is either a directory of ``*.yaml`` task configs, or a
    comma-separated mix of registered task names / wildcard patterns and YAML
    file paths. Registered names come back as whatever
    ``task_manager.match_tasks`` yields. YAML files come back as the configs
    returned by ``utils.load_yaml_config``.

    Raises:
        ValueError: if a non-wildcard entry is neither a known task nor a file.
    """
    if os.path.isdir(args.tasks):
        import glob

        yaml_path = os.path.join(args.tasks, "*.yaml")
        # sorted() keeps the task order deterministic across filesystems.
        return [
            utils.load_yaml_config(yaml_file)
            for yaml_file in sorted(glob.glob(yaml_path))
        ]

    task_list = args.tasks.split(",")
    # Copy so appending loaded configs never mutates task_manager's own data.
    task_names = list(task_manager.match_tasks(task_list))

    unmatched = [task for task in task_list if task not in task_names]
    task_missing = []
    for task in unmatched:
        if os.path.isfile(task):
            # An unregistered entry that points at a YAML file is loaded as a config.
            task_names.append(utils.load_yaml_config(task))
        elif "*" not in task:
            # we don't want errors if a wildcard ("*") task name was used
            task_missing.append(task)

    if task_missing:
        missing = ", ".join(task_missing)
        eval_logger.error(
            f"Tasks were not found: {missing}\n"
            f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
        )
        raise ValueError(
            f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
        )
    return task_names


def _prepare_output_path(output_path: str, eval_logger):
    """Resolve ``--output_path`` into ``(output_dir, results_file)`` and create dirs.

    - An existing *file* at ``output_path`` is rejected.
    - If ``output_path/DEFAULT_RESULTS_FILE`` already exists, a warning is
      logged and that file is reused (it will be overwritten).
    - A ``.json``/``.jsonl`` path is the results file; its parent directory is
      created and returned as the output dir.
    - Any other path is a directory, created if needed, holding
      ``DEFAULT_RESULTS_FILE``.

    Raises:
        FileExistsError: if ``output_path`` is an existing file.
    """
    path = Path(output_path)
    # check if file or 'dir/results.json' exists
    if path.is_file():
        raise FileExistsError(f"File already exists at {path}")
    output_path_file = path.joinpath(DEFAULT_RESULTS_FILE)
    if output_path_file.is_file():
        eval_logger.warning(
            f"File {output_path_file} already exists. Results will be overwritten."
        )
    # if path json then get parent dir
    elif path.suffix in (".json", ".jsonl"):
        output_path_file = path
        path.parent.mkdir(parents=True, exist_ok=True)
        path = path.parent
    else:
        path.mkdir(parents=True, exist_ok=True)
    return path, output_path_file


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an lm-eval evaluation driven by command-line arguments.

    Args:
        args: Pre-parsed arguments. When falsy, they are parsed from
            ``sys.argv`` via ``setup_parser`` / ``parse_eval_args``.

    Side effects:
        Sets the ``TOKENIZERS_PARALLELISM`` environment variable (and
        ``HF_DATASETS_TRUST_REMOTE_CODE`` with ``--trust_remote_code``), and may
        mutate ``args`` (``log_samples``, ``model_args``). It creates output
        directories, writes the results JSON and per-task sample files, prints
        summary tables, and optionally logs to Weights & Biases. It calls
        ``sys.exit()`` when no task is given or after ``--tasks list``.

    Raises:
        ValueError: ``--log_samples``/``--predict_only`` without
            ``--output_path``, or unknown task names.
        FileExistsError: ``--output_path`` points at an existing file.
    """
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        # Listing is the whole job. Falling through would try to evaluate "list".
        sys.exit()

    task_names = _resolve_task_names(args, task_manager, eval_logger)

    if args.output_path:
        path, output_path_file = _prepare_output_path(args.output_path, eval_logger)

    # Only when --trust_remote_code is passed: export it for HF Datasets and
    # append it to the comma-separated model args.
    if args.trust_remote_code:
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
        trust_arg = f"trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
        # Avoid a leading comma when no other model args were given.
        args.model_args = (
            f"{args.model_args},{trust_arg}" if args.model_args else trust_arg
        )

    eval_logger.info(f"Selected Tasks: {task_names}")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        verbosity=args.verbosity,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        **request_caching_args,
    )

    # NOTE(review): presumably None on non-primary ranks in distributed runs; confirm.
    if results is None:
        return

    if args.log_samples:
        # Samples are written per task below, not inside the aggregated results JSON.
        samples = results.pop("samples")
    dumped = json.dumps(
        results, indent=2, default=_handle_non_serializable, ensure_ascii=False
    )
    if args.show_config:
        print(dumped)

    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

    # Add W&B logging
    if args.wandb_args:
        try:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            if args.log_samples:
                wandb_logger.log_eval_samples(samples)
        except Exception as e:
            # Best-effort: a W&B failure must not discard the evaluation results.
            eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

    if args.output_path:
        output_path_file.write_text(dumped, encoding="utf-8")

        if args.log_samples:
            for task_name, config in results["configs"].items():
                # Strip characters that are invalid or awkward in file names.
                output_name = "{}_{}".format(
                    re.sub(r"[\"<>:/\|\\?\*\[\]]+", "__", args.model_args),
                    task_name,
                )
                filename = path.joinpath(f"{output_name}.jsonl")
                # NOTE(review): despite the .jsonl suffix this writes a single
                # indented JSON document, not JSON Lines.
                samples_dumped = json.dumps(
                    samples[task_name],
                    indent=2,
                    default=_handle_non_serializable,
                    ensure_ascii=False,
                )
                filename.write_text(samples_dumped, encoding="utf-8")

    print(
        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
        f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
    )
    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

    if args.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()

# Script entry point: ``python -m lm_eval`` runs the CLI with sys.argv.
if __name__ == "__main__":
    cli_evaluate()