import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_results,
    get_sample_size,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.logging.utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import (
    ConfigurableGroup,
    ConfigurableTask,
    TaskManager,
    get_task_dict,
)
from lm_eval.utils import (
    eval_logger,
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.tasks import Task


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
    fewshot_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str, dict]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `False` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`. `False` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`. `False` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for the fewshot sampler's random generator. If set to None, the generator will not be seeded.

    :return
        Dictionary of results
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    start_date = time.time()

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if tasks is None:
        tasks = []
    if len(tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
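        # simple_parse_args_string converts a comma-separated key=value string
        # (e.g. "temperature=0,top_p=0.95") into a dict of generation kwargs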
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            model_args = ""

        if isinstance(model_args, dict):
            eval_logger.info(
                f"Initializing {model} model, with arguments: {model_args}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        if not isinstance(model, lm_eval.api.model.LM):
            raise TypeError
        eval_logger.info("Using pre-initialized model")
        lm = model

    if use_cache is not None:
        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if check_integrity:
        run_task_tests(task_list=tasks)

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    task_dict = get_task_dict(tasks, task_manager)

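    # Recursively walk the (possibly nested) task dict and apply run-level
    # overrides: CLI generation kwargs, predict_only metric bypass, and the
    # requested num_fewshot / fewshot seed.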
    def _adjust_config(task_dict, predict_only):
        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                adjusted_task_dict = {
                    **adjusted_task_dict,
                    **{task_name: _adjust_config(task_obj, predict_only)},
                }

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs", value=gen_kwargs, update=True
                        )

                if predict_only:
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=num_fewshot)
                    task_obj.set_fewshot_seed(seed=fewshot_random_seed)
                    eval_logger.info(
                        f"Setting fewshot random generator seed to {fewshot_random_seed}"
                    )
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if (
                        default_num_fewshot := task_obj.get_config("num_fewshot")
                    ) is None:
                        task_obj.set_config(key="num_fewshot", value=0)

                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict, predict_only)
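    # predict_only forces per-sample logging, since the raw outputs are then the only result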
    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        write_out=write_out,
        log_samples=True if predict_only else log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, lm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": device,
                "use_cache": use_cache,
                "limit": limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": gen_kwargs,
                "random_seed": random_seed,
                "numpy_seed": numpy_random_seed,
                "torch_seed": torch_random_seed,
                "fewshot_seed": fewshot_random_seed,
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        return results
    else:
        return None


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
    for task_output in eval_tasks:
        task: Task = task_output.task
        limit = get_sample_size(task, limit)
        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
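        # reqtype is the name of the LM method to call, e.g. "loglikelihood" or "generate_until"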
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
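                # note: `req` is the last request from the loop above; it is
                # re-appended as padding so per-rank batch counts stay equal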
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output in eval_tasks:
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            doc_iterator = task.doc_iterator(
                rank=RANK, limit=limit, world_size=WORLD_SIZE
            )
            for doc_id, doc in doc_iterator:
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        results, samples, configs, versions, num_fewshot = consolidate_results(
            eval_tasks
        )

        ### Calculate group metrics ###
        if bool(results):

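            # Recursively fold subtask results into their parent groups,
            # pooling sample sizes and stderrs via lm_eval.api.metrics helpers.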
            def process_group(
                results,
                versions,
                task_dict,
                task_root=None,
                task_hierarchy=None,
                show_group_table=False,
            ):
                if task_root is None:
                    task_root = {}

                if task_hierarchy is None:
                    task_hierarchy = {}

                for group_or_task, group_or_task_info in task_dict.items():
                    # Convert to string
                    if isinstance(group_or_task, ConfigurableGroup):
                        group_config = group_or_task.config
                        group_or_task = group_or_task.group
                    else:
                        group_config = None

                    if isinstance(group_or_task_info, ConfigurableTask):
                        if task_root:
                            task_hierarchy.setdefault(task_root, []).append(
                                group_or_task
                            )
                    else:
                        results, versions, _task_hierarchy, show_group_table = process_group(
                            results,
                            versions,
                            group_or_task_info,
                            group_or_task,
                            task_hierarchy,
                            show_group_table,
                        )
                        if task_root:
                            task_hierarchy.setdefault(task_root, []).extend(
                                task_hierarchy.get(group_or_task, [])
                            )

                        if (group_config is not None) and (group_config["aggregate_metric"] is False):
                            results[group_or_task][" "] = " "
                            continue

                        show_group_table = (
                            show_group_table | group_config["aggregate_metric"]
                        )

                        task_list = _task_hierarchy[group_or_task]
                        metric_list = list(
                            {
                                key
                                for task in task_list
                                for key in results[task].keys()
                                if "_stderr" not in key
                                and key not in ["alias", "samples"]
                            }
                        )
                        for metric in metric_list:
                            stderr = "_stderr,".join(metric.split(","))

                            # gather metrics, sizes, and stderrs from subtasks
                            metrics = [
                                results[task][metric]
                                for task in task_list
                                if metric in results[task]
                            ]  # TODO: copy?
                            stderrs = [
                                results[task][stderr]
                                for task in task_list
                                if stderr in results[task]
                            ]
                            sizes = [
                                results[task]["samples"]
                                for task in task_list
                                if metric in results[task]
                            ]

                            # compute group's pooled metric and stderr
                            results[group_or_task][
                                metric
                            ] = lm_eval.api.metrics.aggregate_subtask_metrics(
                                metrics,
                                sizes,
                                group_config["weight_by_size"],
                            )
                            # TODO: calculate grouped metric using aggregation fn
                            if "N/A" in stderrs:
                                results[group_or_task][stderr] = "N/A"
                            else:
                                results[group_or_task][
                                    stderr
                                ] = lm_eval.api.metrics.pooled_sample_stderr(
                                    stderrs, sizes
                                )
                                # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                                # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                                # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                            results[group_or_task]["samples"] = sum(sizes)
                            versions[group_or_task] = group_config["version"]
                return results, versions, task_hierarchy, show_group_table

            results, versions, task_hierarchy, show_group_table = process_group(
                results, versions, task_dict
            )

        results_agg, group_agg = prepare_print_tasks(task_dict, results)
        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(group_agg.items())}
                if (bool(group_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(task_hierarchy.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output in eval_tasks
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None


def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args
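

# Illustrative mapping for the helper above (comments only, nothing executed):
#   "true"    -> {"cache_requests": True,  "rewrite_requests_cache": False, "delete_requests_cache": False}
#   "refresh" -> {"cache_requests": True,  "rewrite_requests_cache": True,  "delete_requests_cache": False}
#   "delete"  -> {"cache_requests": False, "rewrite_requests_cache": False, "delete_requests_cache": True}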