import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.evaluator import request_caching_arg_to_dict
from lm_eval.logging_utils import WandbLogger
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table


def _handle_non_serializable(o):
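    """Fallback for `json.dumps(default=...)`: cast numpy ints to int and sets to list; stringify anything else."""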
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
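    """Parse a comma-separated list of ints (or 'None') into a list of length `max_len`.

    A single value is broadcast, e.g. "42" -> [42, 42, 42] for max_len=3,
    while "0,None,8" -> [0, None, 8].
    """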
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items


def parse_eval_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
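        # `partial` pins max_len=3: one seed each for python's random, numpy, and torch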
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
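    """Entry point for the lm-eval CLI: parse args (if not supplied), run the evaluation, and report results."""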
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    if args.wandb_args:
        wandb_logger = WandbLogger(args)

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

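    # --predict_only implies --log_samples; both require an --output_path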
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        raise ValueError("Specify --output_path when using --log_samples or --predict_only")

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
    else:
        if os.path.isdir(args.tasks):
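            # --tasks points at a directory: load every YAML task config inside it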
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
                )

    if args.output_path:
        path = Path(args.output_path)
        # check if file or 'dir/results.json' exists
        if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
            eval_logger.warning(
                f"File already exists at {path}. Results will be overwritten."
            )
            output_path_file = path.joinpath("results.json")
            assert not path.is_file(), "File already exists"
        # if path json then get parent dir
        elif path.suffix in (".json", ".jsonl"):
            output_path_file = path
            path.parent.mkdir(parents=True, exist_ok=True)
            path = path.parent
        else:
            path.mkdir(parents=True, exist_ok=True)
            output_path_file = path.joinpath("results.json")

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

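    # translate the --cache_requests flag into the kwargs simple_evaluate expects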
    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        **request_caching_args,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

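        # batch sizes chosen by the evaluator (informative when --batch_size auto is used)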
        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        # Add W&B logging
        if args.wandb_args:
            try:
                wandb_logger.post_init(results)
                wandb_logger.log_eval_result()
                if args.log_samples:
                    wandb_logger.log_eval_samples(samples)
            except Exception as e:
                eval_logger.info(f"Logging to Weights and Biases failed due to {e}")

        if args.output_path:
            output_path_file.open("w", encoding="utf-8").write(dumped)

            if args.log_samples:
                for task_name, config in results["configs"].items():
                    output_name = "{}_{}".format(
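                        # replace '/' and '=' so model_args can be used in a filename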
                        re.sub("/|=", "__", args.model_args), task_name
                    )
                    filename = path.joinpath(f"{output_name}.jsonl")
                    samples_dumped = json.dumps(
                        samples[task_name],
                        indent=2,
                        default=_handle_non_serializable,
                        ensure_ascii=False,
                    )
                    filename.write_text(samples_dumped, encoding="utf-8")

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))

        if args.wandb_args:
            # Tear down wandb run once all the logging is done.
            wandb_logger.run.finish()


if __name__ == "__main__":
    cli_evaluate()