__main__.py
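
# Example invocation (an illustrative sketch; the flags mirror the argparse
# definitions in parse_eval_args below, and the task names and output
# directory are placeholders):
#   lm-eval --model hf \
#       --model_args pretrained=EleutherAI/pythia-160m,dtype=float32 \
#       --tasks task1,task2 --batch_size auto --output_path results/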
import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table


def _handle_non_serializable(o):
    """Fallback for json.dumps: cast numpy ints to int, sets to list, else str."""
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items
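
# An illustrative note (not in the original file): `--seed` below wires this up
# via functools.partial with max_len=3, so for example:
#   _int_or_none_list_arg_type(3, "42")        -> [42, 42, 42]
#   _int_or_none_list_arg_type(3, "0,None,8")  -> [0, None, 8]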


def parse_eval_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get the full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g., `--seed 42` sets all three seeds to 42."
        ),
    )
    return parser.parse_args()
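
# A minimal sketch (not part of the original file) of programmatic use: since
# cli_evaluate accepts an externally-built Namespace, a script can do e.g.
#   args = parse_eval_args()
#   args.limit = 10
#   cli_evaluate(args)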

def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(args)

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError("Specify --output_path when using --log_samples or --predict_only.")

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )
    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
        include_path(args.include_path)

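    # Resolve --tasks: a directory is scanned for *.yaml configs; otherwise names
    # are matched against the registry, and any leftovers are tried as yaml paths.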
    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()
    else:
        if os.path.isdir(args.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
                )

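    # --output_path handling: a .json/.jsonl suffix is treated as the results file
    # itself; any other path is treated as a directory that will hold results.json.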
    if args.output_path:
        path = Path(args.output_path)
        # check if file or 'dir/results.json' exists
        if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
            eval_logger.warning(
                f"File already exists at {path}. Results will be overwritten."
            )
            output_path_file = path.joinpath("results.json")
            assert not path.is_file(), "File already exists"
        # if path json then get parent dir
        elif path.suffix in (".json", ".jsonl"):
            output_path_file = path
            path.parent.mkdir(parents=True, exist_ok=True)
            path = path.parent
        else:
            path.mkdir(parents=True, exist_ok=True)
            output_path_file = path.joinpath("results.json")

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

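    # simple_evaluate drives the whole run; it may return None (presumably on
    # non-primary ranks in distributed settings), hence the guard below.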
    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        if args.output_path:
            output_path_file.open("w", encoding="utf-8").write(dumped)

            if args.log_samples:
                for task_name, config in results["configs"].items():
                    # sanitize model_args so it can be used in a filename
                    output_name = "{}_{}".format(
                        re.sub("/|=", "__", args.model_args), task_name
                    )
                    filename = path.joinpath(f"{output_name}.jsonl")
                    samples_dumped = json.dumps(
                        samples[task_name],
                        indent=2,
                        default=_handle_non_serializable,
                        ensure_ascii=False,
                    )
                    filename.write_text(samples_dumped, encoding="utf-8")

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()


if __name__ == "__main__":
    cli_evaluate()