import itertools
import json
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_results,
    get_sample_size,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    eval_logger,
    handle_non_serializable,
    hash_string,
    positional_deprecated,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.tasks import Task


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    evaluation_tracker: Optional[EvaluationTracker] = None,
    system_instruction: Optional[str] = None,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
    fewshot_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[Union[str, dict]]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `False` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`. `False` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`. `False` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is interpreted as a fraction of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 to skip all stderr calculations.
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: bool
        If True, apply chat template to the prompt
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.
    :param fewshot_random_seed: int
        Random seed for the fewshot sampler's random generator. If set to None, the generator's seed will be set to None.

    :return
        Dictionary of results
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    start_date = time.time()

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

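    # Seeds for Python, NumPy, and Torch are set independently below; passing
    # None for any of them leaves that library's global RNG state untouched.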
    seed_message = []
    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if tasks is None:
        tasks = []
    if len(tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through the CLI; these settings will update the generation parameters set in the YAML task configs. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            model_args = ""

        if isinstance(model_args, dict):
            eval_logger.info(
                f"Initializing {model} model, with arguments: {model_args}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        if not isinstance(model, lm_eval.api.model.LM):
            raise TypeError(
                "model must either be a string naming a registered model or an instantiated lm_eval.api.model.LM subclass"
            )
        eval_logger.info("Using pre-initialized model")
        lm = model

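    # If a cache path is given, wrap the model in CachingLM, which memoizes
    # model responses in a per-rank sqlite database and replays them on reruns.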
    if use_cache is not None:
        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    task_dict = get_task_dict(tasks, task_manager)
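    # Apply per-task overrides from the arguments above (generation kwargs,
    # predict_only, num_fewshot, fewshot seed) before any requests are built.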
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.set_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

        if predict_only:
            log_samples = True
            eval_logger.info(
                f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
            )
            # we have to change the class properties post-hoc. This is pretty hacky.
            task_obj.override_metric(metric_name="bypass")

        # override tasks' fewshot values to the provided num_fewshot arg value
        # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.set_config(key="num_fewshot", value=num_fewshot)
        else:
            # if num_fewshot not provided, and the task does not define a default one, default to 0
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
                task_obj.set_config(key="num_fewshot", value=0)
        # fewshot_random_seed set for tasks, even with a default num_fewshot (e.g. in the YAML file)
        task_obj.set_fewshot_seed(seed=fewshot_random_seed)
        eval_logger.info(
            f"Setting fewshot random generator seed to {fewshot_random_seed}"
        )

    if check_integrity:
        run_task_tests(task_list=tasks)

    if evaluation_tracker is not None:
        evaluation_tracker.general_config_tracker.log_experiment_args(
            model_source=model,
            model_args=model_args,
            system_instruction=system_instruction,
            chat_template=lm.chat_template if apply_chat_template else None,
        )

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        write_out=write_out,
        log_samples=log_samples,
        system_instruction=system_instruction,
        apply_chat_template=apply_chat_template,
        fewshot_as_multiturn=fewshot_as_multiturn,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
        }
        # add more detailed model info if available
        if isinstance(lm, lm_eval.models.huggingface.HFLM):
            results["config"].update(lm.get_model_info())
        # add info about execution
        results["config"].update(
            {
                "batch_size": batch_size,
                "batch_sizes": (
                    list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
                ),
                "device": device,
                "use_cache": use_cache,
                "limit": limit,
                "bootstrap_iters": bootstrap_iters,
                "gen_kwargs": gen_kwargs,
                "random_seed": random_seed,
                "numpy_seed": numpy_random_seed,
                "torch_seed": torch_random_seed,
                "fewshot_seed": fewshot_random_seed,
            }
        )
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        return results
    else:
        return None

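# Illustrative sketch (comment only, not executed): a typical call to
# `simple_evaluate` might look like the following. The model and task names
# here are assumptions chosen purely for illustration.
#
#     results = simple_evaluate(
#         model="hf",
#         model_args="pretrained=EleutherAI/pythia-160m",
#         tasks=["lambada_openai"],
#         num_fewshot=0,
#         batch_size=8,
#     )
#     if results is not None:  # non-zero ranks return None
#         print(results["results"])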

@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    system_instruction: Optional[str] = None,
    apply_chat_template: bool = False,
    fewshot_as_multiturn: bool = False,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name `type(task).config.task`.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics, used when calculating stderr. Set to 0 to skip all stderr calculations.
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param system_instruction: str
        System instruction to be applied to the prompt
    :param apply_chat_template: bool
        If True, apply chat template to the prompt
    :param fewshot_as_multiturn: bool
        Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    task_hierarchy, eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
    for task_output in eval_tasks:
        task: Task = task_output.task
        limit = get_sample_size(task, limit)
        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
            system_instruction=system_instruction,
            apply_chat_template=apply_chat_template,
            fewshot_as_multiturn=fewshot_as_multiturn,
            lm=lm,
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )
        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
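            # gather per-rank instance counts so that every rank can pad its
            # request list up to the size of the largest rank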
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output in eval_tasks:
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            doc_iterator = task.doc_iterator(
                rank=RANK, limit=limit, world_size=WORLD_SIZE
            )
            for doc_id, doc in doc_iterator:
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                        "doc_hash": hash_string(
                            json.dumps(
                                requests[0].doc,
                                indent=2,
                                default=handle_non_serializable,
                                ensure_ascii=False,
                            )
                        ),
                        "prompt_hash": hash_string(requests[0].arguments[0]),
                        "target_hash": hash_string(str(target)),
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        (
            results,
            samples,
            configs,
            versions,
            num_fewshot,
            higher_is_better,
        ) = consolidate_results(eval_tasks)

        ### Calculate group metrics ###
        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if len(task_list) == 0:
                    # task_hierarchy entries are either
                    # `group_name: [subtask1, subtask2, ...]`
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue

                # collect all higher_is_better values for metrics
                # in the group's subtasks.
                # TODO: clean this up ; unify with the below metric_list loop?
                _higher_is_better = {}
                for task in task_list:
                    for m, h in higher_is_better[task].items():
                        if m not in _higher_is_better.keys():
                            _higher_is_better[m] = h
                        if (
                            m in _higher_is_better
                            and _higher_is_better[m] is not None
                            and _higher_is_better[m] != h
                        ):
                            eval_logger.warning(
                                f"Higher_is_better values for metric {m} in group {group} are not consistent. Defaulting to None."
                            )
                            _higher_is_better[m] = None
                higher_is_better[group] = _higher_is_better

                # collect all metric keys used by a subtask in the group.
                metric_list = list(
                    {
                        key
                        for task in task_list
                        for key in results[task].keys()
                        if "_stderr" not in key and key not in ["alias", "samples"]
                    }
                )
                for metric in metric_list:
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
                    metrics = [
                        results[task][metric]
                        for task in task_list
                        if metric in results[task]
                    ]  # TODO: copy?
                    stderrs = [
                        results[task][stderr]
                        for task in task_list
                        if stderr in results[task]
                    ]
                    sizes = [
                        results[task]["samples"]
                        for task in task_list
                        if metric in results[task]
                    ]

                    # compute group's pooled metric and stderr
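                    # (aggregate_subtask_metrics takes a mean over subtask scores,
                    # weighted by each subtask's sample size; pooled_sample_stderr
                    # combines the per-subtask stderrs into a single pooled stderr)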
Baber Abbasi's avatar
Baber Abbasi committed
610
611
612
                    results[group][
                        metric
                    ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
613
614
615
616
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
Baber Abbasi's avatar
Baber Abbasi committed
617
618
619
                        results[group][
                            stderr
                        ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
620
621
622
623
624
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)
lintangsutawika's avatar
lintangsutawika committed
625

        results_agg = defaultdict(dict)
        groups_agg = defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list:
                num_fewshot[group_name] = num_fewshot[
                    task_list[0]
                ]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "group_subtasks": dict(reversed(task_hierarchy.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
            "higher_is_better": dict(sorted(higher_is_better.items())),
            "n-samples": {
                task_output.task_name: {
                    "original": len(task_output.task.eval_docs),
                    "effective": min(
                        limit if limit else len(task_output.task.eval_docs),
                        len(task_output.task.eval_docs),
                    ),
                }
                for task_output in eval_tasks
            },
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None


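# Translates a cache_requests string setting ("true", "refresh", or "delete")
# into the three boolean flags consumed by simple_evaluate above.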
def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args