__main__.py 13.7 KB
Newer Older
1
2
3
import argparse
import json
import logging
lintangsutawika's avatar
lintangsutawika committed
4
import os
lintangsutawika's avatar
lintangsutawika committed
5
import re
6
import sys
7
from functools import partial
8
from pathlib import Path
haileyschoelkopf's avatar
haileyschoelkopf committed
9
from typing import Union
Leo Gao's avatar
Leo Gao committed
10

11
12
import numpy as np

13
from lm_eval import evaluator, utils
14
from lm_eval.evaluator import request_caching_arg_to_dict
15
from lm_eval.logging_utils import WandbLogger
16
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
17
from lm_eval.utils import make_table, simple_parse_args_string
lintangsutawika's avatar
format  
lintangsutawika committed
18

19

20
def _handle_non_serializable(o):
    """Fallback converter for ``json.dumps(..., default=...)``.

    ``json`` calls this only for objects it cannot encode natively.

    Args:
        o: The object the JSON encoder could not serialize.

    Returns:
        A Python ``int`` for any NumPy integer scalar, a Python ``float`` for
        any NumPy floating scalar, a ``list`` for ``set``/``frozenset``
        (element order is whatever iteration yields), and ``str(o)`` for
        everything else.
    """
    # np.integer / np.floating are the abstract bases of every sized NumPy
    # scalar, so int8..uint64 and float16..longdouble are all covered instead
    # of only int32/int64.
    if isinstance(o, np.integer):
        return int(o)
    if isinstance(o, np.floating):
        return float(o)
    if isinstance(o, (set, frozenset)):
        return list(o)
    # Last resort: keep the dump from failing by storing the text form.
    return str(o)
Fabrizio Milo's avatar
Fabrizio Milo committed
27

28

29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse a delimited CLI string into a list of ``int``/``None`` entries.

    Intended for use as an argparse ``type=`` callable (with ``max_len``
    bound in advance).

    Args:
        max_len: Exact number of entries the result must contain.
        value: Raw argument text, e.g. ``"0,None,8"`` or ``"42"``.
        split_char: Delimiter between entries.

    Returns:
        A list of ``max_len`` items, each an ``int`` or ``None``. A single
        entry is repeated ``max_len`` times.

    Raises:
        argparse.ArgumentTypeError: If an entry is neither an integer nor
            ``none`` (case-insensitive), or if the number of entries is
            neither 1 nor ``max_len``.
    """

    def _to_int_or_none(token):
        normalized = token.strip().lower()
        if normalized == "none":
            return None
        try:
            return int(normalized)
        except ValueError:
            raise argparse.ArgumentTypeError(
                f"{normalized} is not an integer or None"
            )

    parsed = list(map(_to_int_or_none, value.split(split_char)))

    if len(parsed) == 1:
        # Broadcast a lone value so callers can always index max_len entries.
        return parsed * max_len

    if len(parsed) != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return parsed


haileyschoelkopf's avatar
haileyschoelkopf committed
53
def parse_eval_args() -> argparse.Namespace:
    """Build the evaluation command-line parser and parse ``sys.argv``.

    Returns:
        The populated ``argparse.Namespace``. Notable value shapes:
        ``seed`` is a list of three ``int``/``None`` entries,
        ``trust_remote_code`` is a ``bool``, and ``batch_size`` is a ``str``
        when given on the command line but the ``int`` 1 when defaulted
        (argparse only applies ``type`` to string defaults).
    """

    def _bool_arg(value: str) -> bool:
        # Plain ``type=bool`` would treat every non-empty string (including
        # "False") as True, so boolean text is parsed explicitly.
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            f"{value} is not a boolean (expected true or false)"
        )

    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    # nargs="?" keeps `--trust_remote_code VALUE` working while also allowing
    # the bare flag (-> const). The value is parsed as a real boolean so that
    # `--trust_remote_code false` actually disables it; the default stays True.
    parser.add_argument(
        "--trust_remote_code",
        nargs="?",
        const=True,
        default=True,
        type=_bool_arg,
        metavar="true|false",
        help=(
            "Sets trust_remote_code to True to execute code to create HF Datasets from the Hub. "
            "Pass `false` to disable. Default True."
        ),
    )

    return parser.parse_args()

Fabrizio Milo's avatar
Fabrizio Milo committed
212

haileyschoelkopf's avatar
haileyschoelkopf committed
213
214
215
216
217
def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an evaluation from command-line style arguments and report results.

    Steps, in order: configure logging, validate the sample-logging options,
    resolve the requested tasks, prepare the output location, call
    ``evaluator.simple_evaluate``, then dump/print/log the results.

    Args:
        args: Pre-parsed arguments. When falsy, ``parse_eval_args()`` is
            called to parse the process's own command line.

    Raises:
        ValueError: If sample logging is requested without ``--output_path``,
            or if a requested (non-wildcard) task cannot be found.
        FileExistsError: If ``--output_path`` names an existing regular file.
        SystemExit: After listing tasks (status 0) or when no task was
            specified (status 1).
    """
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if args.predict_only:
        args.log_samples = True
    if args.log_samples and not args.output_path:
        # Raised explicitly (not asserted) so the check survives `python -O`.
        raise ValueError("Specify --output_path")

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING."
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
        include_path(args.include_path)

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        # Non-zero status: this is a usage error, not a successful run.
        sys.exit(1)
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()
    elif os.path.isdir(args.tasks):
        # A directory: every *.yaml file directly inside it is loaded as a task config.
        import glob

        yaml_path = os.path.join(args.tasks, "*.yaml")
        task_names = [
            utils.load_yaml_config(yaml_file) for yaml_file in glob.glob(yaml_path)
        ]
    else:
        task_list = args.tasks.split(",")
        task_names = task_manager.match_tasks(task_list)
        # Entries that did not match a task name may be paths to YAML config files.
        for task in [task for task in task_list if task not in task_names]:
            if os.path.isfile(task):
                config = utils.load_yaml_config(task)
                task_names.append(config)
        task_missing = [
            task for task in task_list if task not in task_names and "*" not in task
        ]  # we don't want errors if a wildcard ("*") task name was used

        if task_missing:
            missing = ", ".join(task_missing)
            eval_logger.error(
                f"Tasks were not found: {missing}\n"
                f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
            )
            raise ValueError(
                f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
            )

    if args.output_path:
        path = Path(args.output_path)
        if path.is_file():
            # An existing regular file is refused up front, before any
            # evaluation time is spent.
            raise FileExistsError(f"File already exists at {path}")
        default_results_file = path.joinpath("results.json")
        if default_results_file.is_file():
            # Existing directory that already holds a results.json.
            eval_logger.warning(
                f"File already exists at {default_results_file}. Results will be overwritten."
            )
            output_path_file = default_results_file
        # if path json then get parent dir
        elif path.suffix in (".json", ".jsonl"):
            output_path_file = path
            path.parent.mkdir(parents=True, exist_ok=True)
            path = path.parent
        else:
            path.mkdir(parents=True, exist_ok=True)
            output_path_file = default_results_file

    # Respect user's value passed in via CLI, otherwise default to True and add to comma-separated model args
    trust_remote_code = args.trust_remote_code
    if isinstance(trust_remote_code, str):
        # Text values must be interpreted, otherwise "false" would count as truthy.
        trust_remote_code = trust_remote_code.strip().lower() not in (
            "",
            "0",
            "f",
            "false",
            "n",
            "no",
        )
    if trust_remote_code:
        # os.environ only accepts str values; assigning a bool raises TypeError.
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = "True"
        trust_arg = "trust_remote_code=True"
        # Avoid a leading comma when no other model args were given.
        args.model_args = (
            f"{args.model_args},{trust_arg}" if args.model_args else trust_arg
        )

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        **request_caching_args,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    if results is not None:
        if args.log_samples:
            # Samples are removed from `results` so they are written per task
            # below rather than inside the main results dump.
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            # Best effort: a W&B failure must not discard the evaluation results.
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        if args.output_path:
            # write_text opens, writes and closes the file in one call.
            output_path_file.write_text(dumped, encoding="utf-8")

            if args.log_samples:
                # "/" and "=" in the model args are replaced to build a filename prefix.
                model_args_tag = re.sub("/|=", "__", args.model_args)
                for task_name in results["configs"]:
                    output_name = "{}_{}".format(model_args_tag, task_name)
                    filename = path.joinpath(f"{output_name}.jsonl")
                    samples_dumped = json.dumps(
                        samples[task_name],
                        indent=2,
                        default=_handle_non_serializable,
                        ensure_ascii=False,
                    )
                    filename.write_text(samples_dumped, encoding="utf-8")

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()

389

Jason Phang's avatar
Jason Phang committed
390
# Script entry point: parse the command line and run the evaluation when this
# file is executed directly rather than imported.
if __name__ == "__main__":
    cli_evaluate()