__main__.py 14 KB
Newer Older
1
2
3
import argparse
import json
import logging
lintangsutawika's avatar
lintangsutawika committed
4
import os
5
import sys
6
from functools import partial
haileyschoelkopf's avatar
haileyschoelkopf committed
7
from typing import Union
Leo Gao's avatar
Leo Gao committed
8

9
from lm_eval import evaluator, utils
10
from lm_eval.evaluator import request_caching_arg_to_dict
11
from lm_eval.logging import EvaluationTracker, WandbLogger
12
from lm_eval.tasks import TaskManager
13
from lm_eval.utils import handle_non_serializable, make_table, simple_parse_args_string
Fabrizio Milo's avatar
Fabrizio Milo committed
14

15

16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
def _int_or_none_list_arg_type(max_len: int, value: str, split_char: str = ","):
    """Parse a ``split_char``-separated list of integers and/or ``None`` (argparse ``type``).

    Each item is stripped and lower-cased before conversion, so ``"None"``,
    ``" none "`` etc. all become ``None``; every other item must be accepted
    by ``int``.

    Args:
        max_len: Number of values the option requires.
        value: Raw command-line string, e.g. ``"0,None,8"`` or ``"42"``.
        split_char: Separator between values.

    Returns:
        A list of ``int``/``None`` items. A single value is repeated
        ``max_len`` times so callers handle the one-value and many-value forms
        the same way.

    Raises:
        argparse.ArgumentTypeError: if an item is neither an integer nor
            ``none``, or if more than one item is given but not exactly
            ``max_len``.
    """

    def _convert(token):
        token = token.strip().lower()
        if token == "none":
            return None
        try:
            return int(token)
        except ValueError:
            raise argparse.ArgumentTypeError(f"{token} is not an integer or None")

    parsed = list(map(_convert, value.split(split_char)))

    if len(parsed) == 1:
        # Makes downstream handling the same for single and multiple values
        return parsed * max_len
    if len(parsed) != max_len:
        raise argparse.ArgumentTypeError(
            f"Argument requires {max_len} integers or None, separated by '{split_char}'"
        )
    return parsed


def check_argument_types(parser: argparse.ArgumentParser):
    """
    Ensure every CLI option registered on ``parser`` declares a ``type``.

    The built-in ``help`` action and actions with a truthy ``const`` (such as
    ``store_true`` flags) are exempt. Raises ValueError naming the first
    offending destination.
    """
    for act in parser._actions:
        # Skip help and flag-style actions that carry their value in ``const``.
        if act.dest == "help" or act.const:
            continue
        if act.type is None:
            raise ValueError(f"Argument '{act.dest}' doesn't have a type specified.")


def setup_parser() -> argparse.ArgumentParser:
    """Build the ``lm-eval`` command-line argument parser.

    Every option except the ``store_true`` flags declares a ``type`` so that
    ``check_argument_types`` accepts the parser.

    Returns:
        A parser using ``RawTextHelpFormatter`` so explicit newlines in help
        strings (e.g. for ``--seed``) are preserved.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "--model", "-m", type=str, default="hf", help="Name of model e.g. `hf`"
    )
    parser.add_argument(
        "--tasks",
        "-t",
        default=None,
        type=str,
        metavar="task1,task2",
        help="To get full list of tasks, use the command lm-eval --tasks list",
    )
    parser.add_argument(
        "--model_args",
        "-a",
        default="",
        type=str,
        help="Comma separated string arguments for model, e.g. `pretrained=EleutherAI/pythia-160m,dtype=float32`",
    )
    parser.add_argument(
        "--num_fewshot",
        "-f",
        type=int,
        default=None,
        metavar="N",
        help="Number of examples in few-shot context",
    )
    parser.add_argument(
        "--batch_size",
        "-b",
        type=str,
        # argparse only runs ``type`` on *string* defaults, so the default
        # stays the int 1 while user-supplied values arrive as str.
        default=1,
        metavar="auto|auto:N|N",
        help="Acceptable values are 'auto', 'auto:N' or N, where N is an integer. Default 1.",
    )
    parser.add_argument(
        "--max_batch_size",
        type=int,
        default=None,
        metavar="N",
        help="Maximal batch size to try with --batch_size auto.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. cuda, cuda:0, cpu).",
    )
    parser.add_argument(
        "--output_path",
        "-o",
        default=None,
        type=str,
        metavar="DIR|DIR/file.json",
        help="The path to the output file where the result metrics will be saved. If the path is a directory and log_samples is true, the results will be saved in the directory. Else the parent directory will be used.",
    )
    parser.add_argument(
        "--limit",
        "-L",
        type=float,
        default=None,
        metavar="N|0<N<1",
        help="Limit the number of examples per task. "
        "If <1, limit is a percentage of the total number of examples.",
    )
    parser.add_argument(
        "--use_cache",
        "-c",
        type=str,
        default=None,
        metavar="DIR",
        help="A path to a sqlite db file for caching model responses. `None` if not caching.",
    )
    parser.add_argument(
        "--cache_requests",
        type=str,
        default=None,
        choices=["true", "refresh", "delete"],
        help="Speed up evaluation by caching the building of dataset requests. `None` if not caching.",
    )
    parser.add_argument(
        "--check_integrity",
        action="store_true",
        help="Whether to run the relevant part of the test suite for the tasks.",
    )
    parser.add_argument(
        "--write_out",
        "-w",
        action="store_true",
        default=False,
        help="Prints the prompt for the first few documents.",
    )
    parser.add_argument(
        "--log_samples",
        "-s",
        action="store_true",
        default=False,
        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis. Use with --output_path.",
    )
    parser.add_argument(
        "--show_config",
        action="store_true",
        default=False,
        help="If True, shows the full config of all tasks at the end of the evaluation.",
    )
    parser.add_argument(
        "--include_path",
        type=str,
        default=None,
        metavar="DIR",
        help="Additional path to include if there are external tasks to include.",
    )
    parser.add_argument(
        "--gen_kwargs",
        type=str,
        default=None,
        help=(
            "String arguments for model generation on greedy_until tasks,"
            " e.g. `temperature=0,top_k=0,top_p=0`."
        ),
    )
    parser.add_argument(
        "--verbosity",
        "-v",
        # Normalize so e.g. "debug" resolves to logging.DEBUG later on.
        type=str.upper,
        default="INFO",
        metavar="CRITICAL|ERROR|WARNING|INFO|DEBUG",
        help="Controls the reported logging error level. Set to DEBUG when testing + adding new task configurations for comprehensive log output.",
    )
    parser.add_argument(
        "--wandb_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to wandb.init, e.g. `project=lm-eval,job_type=eval`",
    )
    parser.add_argument(
        "--hf_hub_log_args",
        type=str,
        default="",
        help="Comma separated string arguments passed to Hugging Face Hub's log function, e.g. `hub_results_org=EleutherAI,hub_repo_name=lm-eval-results`",
    )
    parser.add_argument(
        "--predict_only",
        "-x",
        action="store_true",
        default=False,
        help="Use with --log_samples. Only model outputs will be saved and metrics will not be evaluated.",
    )
    parser.add_argument(
        "--seed",
        # Always yields a 3-item list: (python random, numpy, torch) seeds.
        type=partial(_int_or_none_list_arg_type, 3),
        default="0,1234,1234",  # for backward compatibility
        help=(
            "Set seed for python's random, numpy and torch.\n"
            "Accepts a comma-separated list of 3 values for python's random, numpy, and torch seeds, respectively, "
            "or a single integer to set the same seed for all three.\n"
            "The values are either an integer or 'None' to not set the seed. Default is `0,1234,1234` (for backward compatibility).\n"
            "E.g. `--seed 0,None,8` sets `random.seed(0)` and `torch.manual_seed(8)`. Here numpy's seed is not set since the second value is `None`.\n"
            "E.g, `--seed 42` sets all three seeds to 42."
        ),
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        help="Sets trust_remote_code to True to execute code to create HF Datasets from the Hub",
    )
    return parser


def parse_eval_args(parser: argparse.ArgumentParser) -> argparse.Namespace:
    """Check that every option of ``parser`` is typed, then parse ``sys.argv``.

    Raises:
        ValueError: propagated from ``check_argument_types`` when an option
            has no ``type``.
    """
    check_argument_types(parser)
    return parser.parse_args()


def _build_hf_hub_log_args(output_path, hf_hub_log_args: str) -> str:
    """Assemble the comma-separated argument string for ``EvaluationTracker``.

    ``output_path`` and the ``HF_TOKEN`` environment variable are included only
    when set; unset values are skipped instead of being formatted as the
    literal text ``None`` (e.g. ``output_path=None``).
    NOTE(review): relies on ``EvaluationTracker`` supplying its own defaults for
    omitted ``output_path`` / ``token`` keys — confirm against its signature.

    The user-supplied ``hf_hub_log_args`` come last, as before, so a key the user
    repeats should win (assuming the parser keeps the last occurrence).
    """
    parts = []
    if output_path:
        parts.append(f"output_path={output_path}")
    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        parts.append(f"token={hf_token}")
    if hf_hub_log_args:
        parts.append(hf_hub_log_args)
    return ",".join(parts)


def _resolve_task_names(tasks: str, task_manager, eval_logger) -> list:
    """Turn a ``--tasks`` value into the list handed to ``simple_evaluate``.

    ``tasks`` is either a directory, in which case every ``*.yaml`` file in it is
    loaded with ``utils.load_yaml_config``, or a comma-separated list of task
    names / patterns (matched via ``task_manager.match_tasks``) and paths to
    YAML task files.

    Returns:
        A list mixing matched task names and loaded YAML configs.

    Raises:
        ValueError: if an entry without a ``*`` wildcard is neither a matched
            task nor an existing file.
    """
    if os.path.isdir(tasks):
        import glob

        yaml_path = os.path.join(tasks, "*.yaml")
        return [utils.load_yaml_config(yaml_file) for yaml_file in glob.glob(yaml_path)]

    task_list = tasks.split(",")
    task_names = task_manager.match_tasks(task_list)
    # Entries that did not match a registered task may be paths to YAML files.
    # Remember them: the appended config objects never compare equal to the
    # path strings, so without this they would be reported as missing.
    loaded_from_file = set()
    for task in [task for task in task_list if task not in task_names]:
        if os.path.isfile(task):
            task_names.append(utils.load_yaml_config(task))
            loaded_from_file.add(task)

    # we don't want errors if a wildcard ("*") task name was used
    task_missing = [
        task
        for task in task_list
        if task not in task_names
        and task not in loaded_from_file
        and "*" not in task
    ]
    if task_missing:
        missing = ", ".join(task_missing)
        eval_logger.error(
            f"Tasks were not found: {missing}\n"
            f"{utils.SPACING}Try `lm-eval --tasks list` for list of available tasks",
        )
        raise ValueError(
            f"Tasks not found: {missing}. Try `lm-eval --tasks list` for list of available tasks, or '--verbosity DEBUG' to troubleshoot task registration issues."
        )
    return task_names


def cli_evaluate(args: Union[argparse.Namespace, None] = None) -> None:
    """Run an evaluation from command-line style arguments and report results.

    Args:
        args: Namespace shaped like ``setup_parser()`` output. When ``None``
            (the default) ``sys.argv`` is parsed, so other code can call this
            with its own namespace instead.

    Raises:
        ValueError: for inconsistent options (``--log_samples`` /
            ``--predict_only`` without ``--output_path``, hub pushing without
            ``hub_results_org``) or when requested tasks cannot be found.
        SystemExit: status 1 when no task is given; status 0 after printing the
            task list for ``--tasks list``.
    """
    if args is None:
        # we allow for args to be passed externally, else we parse them ourselves
        parser = setup_parser()
        args = parse_eval_args(parser)

    if args.wandb_args:
        wandb_logger = WandbLogger(**simple_parse_args_string(args.wandb_args))

    eval_logger = utils.eval_logger
    eval_logger.setLevel(getattr(logging, f"{args.verbosity}"))
    eval_logger.info(f"Verbosity set to {args.verbosity}")
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # update the evaluation tracker args with the output path and the HF token
    args.hf_hub_log_args = _build_hf_hub_log_args(
        args.output_path, args.hf_hub_log_args
    )
    evaluation_tracker_args = simple_parse_args_string(args.hf_hub_log_args)
    evaluation_tracker = EvaluationTracker(**evaluation_tracker_args)
    evaluation_tracker.general_config_tracker.log_experiment_args(
        model_source=args.model,
        model_args=args.model_args,
    )

    # --predict_only implies --log_samples, and both need a place to write to.
    if args.predict_only:
        args.log_samples = True
    if args.log_samples and not args.output_path:
        raise ValueError(
            "Specify --output_path if providing --log_samples or --predict_only"
        )

    if args.include_path is not None:
        eval_logger.info(f"Including path: {args.include_path}")
    task_manager = TaskManager(args.verbosity, include_path=args.include_path)

    # Check the values, not mere key presence, so `push_results_to_hub=False`
    # does not demand an organisation.
    push_results = evaluation_tracker_args.get("push_results_to_hub")
    push_samples = evaluation_tracker_args.get("push_samples_to_hub")
    if (push_results or push_samples) and "hub_results_org" not in evaluation_tracker_args:
        raise ValueError(
            "If push_results_to_hub or push_samples_to_hub is set, hub_results_org must be specified."
        )
    if push_samples and not args.log_samples:
        eval_logger.warning(
            "Pushing samples to the Hub requires --log_samples to be set. Samples will not be pushed to the Hub."
        )

    if args.limit:
        eval_logger.warning(
            " --limit SHOULD ONLY BE USED FOR TESTING. "
            "REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
        )

    if args.tasks is None:
        eval_logger.error("Need to specify task to evaluate.")
        sys.exit(1)  # non-zero: this is an error, not a successful run
    if args.tasks == "list":
        eval_logger.info(
            "Available Tasks:\n - {}".format("\n - ".join(task_manager.all_tasks))
        )
        sys.exit()
    task_names = _resolve_task_names(args.tasks, task_manager, eval_logger)

    # Only when --trust_remote_code is given: export it for HF Datasets and
    # forward it to the model through the comma-separated model args.
    if args.trust_remote_code:
        os.environ["HF_DATASETS_TRUST_REMOTE_CODE"] = str(args.trust_remote_code)
        trust_arg = f"trust_remote_code={os.environ['HF_DATASETS_TRUST_REMOTE_CODE']}"
        args.model_args = (
            f"{args.model_args},{trust_arg}" if args.model_args else trust_arg
        )

    eval_logger.info(f"Selected Tasks: {task_names}")

    request_caching_args = request_caching_arg_to_dict(
        cache_requests=args.cache_requests
    )

    results = evaluator.simple_evaluate(
        model=args.model,
        model_args=args.model_args,
        tasks=task_names,
        num_fewshot=args.num_fewshot,
        batch_size=args.batch_size,
        max_batch_size=args.max_batch_size,
        device=args.device,
        use_cache=args.use_cache,
        limit=args.limit,
        check_integrity=args.check_integrity,
        write_out=args.write_out,
        log_samples=args.log_samples,
        gen_kwargs=args.gen_kwargs,
        task_manager=task_manager,
        verbosity=args.verbosity,
        predict_only=args.predict_only,
        random_seed=args.seed[0],
        numpy_random_seed=args.seed[1],
        torch_random_seed=args.seed[2],
        **request_caching_args,
    )

    # Nothing to report or save when simple_evaluate returns None.
    if results is None:
        return

    # Per-sample outputs are split off so they are not part of the aggregate.
    samples = results.pop("samples") if args.log_samples else None

    if args.show_config:
        print(
            json.dumps(
                results, indent=2, default=handle_non_serializable, ensure_ascii=False
            )
        )

    batch_sizes = ",".join(map(str, results["config"]["batch_sizes"]))

    # Add W&B logging; a failure here must not abort saving/printing results.
    if args.wandb_args:
        try:
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            if args.log_samples:
                wandb_logger.log_eval_samples(samples)
        except Exception as e:
            eval_logger.warning(f"Logging to Weights and Biases failed due to {e}")

    evaluation_tracker.save_results_aggregated(results=results, samples=samples)

    if args.log_samples:
        for task_name in results["configs"]:
            evaluation_tracker.save_results_samples(
                task_name=task_name, samples=samples[task_name]
            )

    print(
        f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, "
        f"batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
    )
    print(make_table(results))
    if "groups" in results:
        print(make_table(results, "groups"))

    if args.wandb_args:
        # Tear down wandb run once all the logging is done.
        wandb_logger.run.finish()


if __name__ == "__main__":
    # Entry point for `python -m lm_eval`: parse sys.argv and run the evaluation.
    cli_evaluate()