__main__.py 13.1 KB
Newer Older
1
2
3
import argparse
import json
import logging
lintangsutawika's avatar
lintangsutawika committed
4
import os
lintangsutawika's avatar
lintangsutawika committed
5
import re
6
import sys
7
from functools import partial
8
from pathlib import Path
haileyschoelkopf's avatar
haileyschoelkopf committed
9
from typing import Union
Leo Gao's avatar
Leo Gao committed
10

11
12
import numpy as np

13
from lm_eval import evaluator, utils
14
from lm_eval.evaluator import request_caching_arg_to_dict
15
from lm_eval.logging_utils import WandbLogger
16
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
17
from lm_eval.utils import make_table, simple_parse_args_string
lintangsutawika's avatar
format  
lintangsutawika committed
18

19

20
def _handle_non_serializable(o):
    """Fallback converter passed as ``default=`` to ``json.dumps``.

    ``json`` only calls this for objects it cannot serialize natively. It maps:

    * any numpy integer scalar (``np.int8`` ... ``np.uint64``) to a Python ``int``;
    * a ``set`` to a ``list`` (element order follows the set's iteration order);
    * anything else to its ``str()`` form, so dumping never fails.

    Args:
        o: The object ``json`` could not serialize.

    Returns:
        A JSON-serializable replacement for ``o``.
    """
    # ``np.integer`` is the common base class of every numpy integer dtype.
    # Checking only int64/int32 silently stringified e.g. np.int16 / np.uint8.
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, set):
        return list(o)
    return str(o)
Fabrizio Milo's avatar
Fabrizio Milo committed
27

28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse a ``split_char``-separated list of integers / ``None`` for argparse.

    Each element is stripped and lower-cased. The literal ``none`` becomes
    ``None``; any other element must parse as ``int``. A single element is
    broadcast to ``max_len`` copies. Any other element count that differs from
    ``max_len`` is rejected.

    Args:
        max_len: Number of values the caller expects.
        value: Raw command-line string, e.g. ``"0,None,8"`` or ``"42"``.
        split_char: Separator between elements.

    Returns:
        A list of exactly ``max_len`` ints / ``None`` values.

    Raises:
        argparse.ArgumentTypeError: If an element is not an integer or
            ``None``, or if the element count is neither 1 nor ``max_len``.
    """

    def _to_int_or_none(raw):
        token = raw.strip().lower()
        if token == "none":
            return None
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{token} is not an integer or None")

    parsed = list(map(_to_int_or_none, value.split(split_char)))

    if len(parsed) == 1:
        # Broadcasting a lone value lets downstream code treat one value and a
        # full list identically.
        return parsed * max_len
    if len(parsed) != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    return parsed


haileyschoelkopf's avatar
haileyschoelkopf committed
53
def parse_eval_args(argv: Union[list, None] = None) -> argparse.Namespace:
    """Build the lm-eval command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. When ``None`` (the default), argparse reads
            ``sys.argv``, matching the previous no-argument behaviour.

    Returns:
        The parsed ``argparse.Namespace``. ``--seed`` is always a 3-element
        list (python random, numpy, torch) because argparse runs ``type`` on
        string defaults as well as on user-supplied values.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        # NOTE: argparse does not run ``type`` on non-string defaults, so the
        # default arrives as the int 1 while user-supplied values are strings.
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        # Upper-cased so the value can be looked up directly on ``logging``.
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g. `--seed 42` sets all three seeds to 42."
        ),
    )

    return parser.parse_args(argv)

Fabrizio Milo's avatar
Fabrizio Milo committed
206

haileyschoelkopf's avatar
haileyschoelkopf committed
207
208
209
210
211
def _resolve_task_names(tasks: str, task_manager: "TaskManager") -> list:
    """Turn the ``--tasks`` value into the list handed to ``simple_evaluate``.

    The list may contain registered task names and/or loaded YAML config dicts.

    * If ``tasks`` is a directory, every ``*.yaml`` file in it is loaded as a
      task config.
    * Otherwise ``tasks`` is split on ``,``. Each entry is matched against
      registered tasks via ``task_manager.match_tasks``. Unmatched entries that
      are existing files are loaded as YAML configs.

    Raises:
        ValueError: If a non-wildcard entry is neither a registered task nor a
            loadable file.
    """
    if os.path.isdir(tasks):
        import glob

        yaml_path = os.path.join(tasks, "*.yaml")
        return [utils.load_yaml_config(yaml_file) for yaml_file in glob.glob(yaml_path)]

    task_list = tasks.split(",")
    task_names = task_manager.match_tasks(task_list)
    # Entries resolved from a YAML file appear in ``task_names`` as config
    # dicts, not as the path string. Track them so they are not reported as
    # missing below.
    loaded_from_file = set()
    for task in [task for task in task_list if task not in task_names]:
        if os.path.isfile(task):
            task_names.append(utils.load_yaml_config(task))
            loaded_from_file.add(task)

    # We don't want errors if a wildcard ("*") task name was used.
    task_missing = [
        task
        for task in task_list
        if task not in task_names and task not in loaded_from_file and "*" not in task
    ]
    if task_missing:
        missing = ", ".join(task_missing)
        utils.eval_logger.error(
            f"Tasks were not found: {missing}\n"
            f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
        )
        raise ValueError(
            f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
        )
    return task_names


def _resolve_output_path(output_path: str) -> tuple:
    """Work out where aggregated results and per-sample logs are written.

    Returns:
        ``(path, output_path_file)``. ``output_path_file`` is the aggregated
        results JSON file; ``path`` is the directory per-sample ``.jsonl``
        files are written into. Missing directories are created.

        * An ``output_path`` ending in ``.json``/``.jsonl`` is used as the
          results file, and its parent becomes the samples directory.
        * Anything else is treated as a directory holding ``results.json``.
          An existing ``results.json`` there is overwritten, with a warning.

    Raises:
        FileExistsError: If ``output_path`` is an existing file.
    """
    path = Path(output_path)
    # Raised explicitly: an ``assert`` here disappears under ``python -O``,
    # and the failure would then only surface after the whole evaluation ran.
    if path.is_file():
        raise FileExistsError(f"File already exists at {path}")

    if path.joinpath("results.json").is_file():
        output_path_file = path.joinpath("results.json")
        utils.eval_logger.warning(
            f"File already exists at {output_path_file}. Results will be overwritten."
        )
    elif path.suffix in (".json", ".jsonl"):
        # A file path was given: results go there, samples go next to it.
        output_path_file = path
        path.parent.mkdir(parents=True, exist_ok=True)
        path = path.parent
    else:
        path.mkdir(parents=True, exist_ok=True)
        output_path_file = path.joinpath("results.json")
    return path, output_path_file


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an lm-eval evaluation configured by command-line arguments.

    Steps:

    1. Resolve the requested tasks.
    2. Call ``evaluator.simple_evaluate``.
    3. Print the results tables.
    4. Optionally write results and per-sample logs under ``--output_path``.
    5. Optionally log to Weights & Biases.

    Args:
        args: A pre-parsed namespace. When ``None``, arguments are parsed from
            ``sys.argv`` via ``parse_eval_args``.

    Raises:
        ValueError: ``--log_samples``/``--predict_only`` without
            ``--output_path``, or requested tasks that cannot be found.
        FileExistsError: ``--output_path`` points at an existing file.
        SystemExit: After listing tasks (status 0) or when no task was given
            (status 1).
    """
    if args is None:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # predict_only implies log_samples, so a single check covers both flags.
    if args.predict_only:
        args.log_samples = True
    if args.log_samples and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
        include_path(args.include_path)

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        # Non-zero status so shells/CI notice the invocation failed.
        sys.exit(1)
    if args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()
    task_names = _resolve_task_names(args.tasks, task_manager)

    if args.output_path:
        path, output_path_file = _resolve_output_path(args.output_path)

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        **request_caching_args,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    # simple_evaluate may return None; in that case there is nothing to report.
    if results is None:
        return

    if args.log_samples:
        samples = results.pop("samples")
    dumped = json.dumps(
        results, indent=2, default=_handle_non_serializable, ensure_ascii=False
    )
    if args.show_config:
        print(dumped)

    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

    # W&B logging is best-effort: a failure must not lose the local results.
    if args.wandb_args:
        try:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            if args.log_samples:
                wandb_logger.log_eval_samples(samples)
        except Exception as e:
            eval_logger.warning(f"Logging to Weights and Biases failed due to {e}")

    if args.output_path:
        # write_text opens and closes the file (the old open().write() leaked it).
        output_path_file.write_text(dumped, encoding="utf-8")

        if args.log_samples:
            # Characters that are awkward in filenames are replaced with "__".
            model_args_slug = re.sub("/|=", "__", args.model_args)
            for task_name in results["configs"]:
                filename = path.joinpath(f"{model_args_slug}_{task_name}.jsonl")
                samples_dumped = json.dumps(
                    samples[task_name],
                    indent=2,
                    default=_handle_non_serializable,
                    ensure_ascii=False,
                )
                filename.write_text(samples_dumped, encoding="utf-8")

    print(
        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
        f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
    )
    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

    if args.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()

373

Jason Phang's avatar
Jason Phang committed
374
if __name__ == "__main__":
    # Script entry point: with no namespace supplied, cli_evaluate parses
    # sys.argv itself and runs the evaluation.
    cli_evaluate(args=None)