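"""Command-line interface for lm_eval (`lm-eval` / `python -m lm_eval`).

Parses CLI arguments, resolves the requested tasks, runs `evaluator.simple_evaluate`,
and reports (and optionally saves) the results.
"""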
import argparse
import json
import logging
import os
import re
import sys
from functools import partial
from pathlib import Path
from typing import Union

import numpy as np

from lm_eval import evaluator, utils
from lm_eval.tasks import TaskManager, include_path, initialize_tasks
from lm_eval.utils import make_table


def _handle_non_serializable(o):
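    """json.dumps fallback for objects it cannot serialize: numpy ints -> int, sets -> list, anything else -> str."""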
    if isinstance(o, np.int64) or isinstance(o, np.int32):
        return int(o)
    elif isinstance(o, set):
        return list(o)
    else:
        return str(o)


def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
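    """argparse `type` helper: parse a `split_char`-separated string into a list of exactly `max_len` ints/None.

    A single value is broadcast to `max_len` entries; any other length raises an ArgumentTypeError.
    """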
    def parse_value(item):
        item = item.strip().lower()
        if item == "none":
            return None
        try:
            return int(item)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{item} is not an integer or None")

    items = [parse_value(v) for v in value.split(split_char)]
    num_items = len(items)

    if num_items == 1:
        # Makes downstream handling the same for single and multiple values
        items = items * max_len
    elif num_items != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )

    return items


def parse_eval_args() -> argparse.Namespace:
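    """Build the lm-eval command-line argument parser and return the parsed arguments."""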
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument("--model", "-m", default="hf", help="Name of model e.g. `hf`")
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        metavar="task1,task2",
        help="To get the full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        help="Comma-separated string arguments for the model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in the few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximum batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory; otherwise, the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a fraction of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument("--decontamination_ngrams_path", default=None)  # TODO: not used
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to load.",
    )
    parser.add_argument(
        "--gen_kwargs",
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging level. Set to DEBUG when testing or adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    return parser.parse_args()


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
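    """Entry point for the `lm-eval` CLI.

    Parses arguments (unless an `argparse.Namespace` is passed in), resolves the
    requested tasks and output paths, runs `evaluator.simple_evaluate`, and then
    prints and optionally saves the results.
    """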
    if not args:
        # we allow for args to be passed externally, else we parse them ourselves
        args = parse_eval_args()

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

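    # --predict_only implies --log_samples, and both require an --output_path to write to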
    if args.predict_only:
        args.log_samples = True
    if (args.log_samples or args.predict_only) and not args.output_path:
        assert args.output_path, "Specify --output_path"

    initialize_tasks(args.verbosity)
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify a task to evaluate.")
        sys.exit()
    elif args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
    else:
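        # Resolve --tasks: either a directory of YAML task configs, or a comma-separated
        # list of registered task names and/or paths to individual YAML config files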
        if os.path.isdir(args.tasks):
            import glob

            task_names = []
            yaml_path = os.path.join(args.tasks, "*.yaml")
            for yaml_file in glob.glob(yaml_path):
                config = utils.load_yaml_config(yaml_file)
                task_names.append(config)
        else:
            task_list = args.tasks.split(",")
            task_names = task_manager.match_tasks(task_list)
            for task in [task for task in task_list if task not in task_names]:
                if os.path.isfile(task):
                    config = utils.load_yaml_config(task)
                    task_names.append(config)
            task_missing = [
                task for task in task_list if task not in task_names and "*" not in task
            ]  # we don't want errors if a wildcard ("*") task name was used

            if task_missing:
                missing = ", ".join(task_missing)
                eval_logger.error(
                    f"Tasks were not found: {missing}\n"
                    f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
                )
                raise ValueError(
                    f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
                )

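    # Resolve --output_path: a directory receives results.json; a .json/.jsonl path is
    # used directly, and its parent directory is used for any per-sample logs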
    if args.output_path:
        path = Path(args.output_path)
        # check if file or 'dir/results.json' exists
        if path.is_file() or Path(args.output_path).joinpath("results.json").is_file():
            eval_logger.warning(
                f"File already exists at {path}. Results will be overwritten."
            )
            output_path_file = path.joinpath("results.json")
            assert not path.is_file(), "File already exists"
        # if path json then get parent dir
        elif path.suffix in (".json", ".jsonl"):
            output_path_file = path
            path.parent.mkdir(parents=True, exist_ok=True)
            path = path.parent
        else:
            path.mkdir(parents=True, exist_ok=True)
            output_path_file = path.joinpath("results.json")

    eval_logger.info(f"Selected Tasks: {task_names}")
    eval_logger.info("Loading selected tasks...")

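    # Run the evaluation; simple_evaluate may return None (e.g. on non-primary ranks
    # in distributed runs), hence the guard below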
    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        decontamination_ngrams_path=args.decontamination_ngrams_path,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
    )

    if results is not None:
        if args.log_samples:
            samples = results.pop("samples")
        dumped = json.dumps(
            results, indent=2, default=_handle_non_serializable, ensure_ascii=False
        )
        if args.show_config:
            print(dumped)

        batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

        if args.output_path:
            output_path_file.open("w", encoding="utf-8").write(dumped)

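            # Write per-task sample logs next to results.json, one
            # <sanitized model_args>_<task>.jsonl file per task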
            if args.log_samples:
                for task_name, config in results["configs"].items():
                    output_name = "{}_{}".format(
                        re.sub("/|=", "__", args.model_args), task_name
                    )
                    filename = path.joinpath(f"{output_name}.jsonl")
                    samples_dumped = json.dumps(
                        samples[task_name],
                        indent=2,
                        default=_handle_non_serializable,
                        ensure_ascii=False,
                    )
                    filename.write_text(samples_dumped, encoding="utf-8")

        print(
            f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
            f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
        )
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))


if __name__ == "__main__":
    cli_evaluate()