import itertools
import logging
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.registry
import lm_eval.api.task
import lm_eval.models
from lm_eval.caching.cache import delete_cache
from lm_eval.evaluator_utils import (
    consolidate_results,
    get_sample_size,
    get_task_list,
    prepare_print_tasks,
    print_writeout,
    run_task_tests,
)
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import (
    ConfigurableGroup,
    ConfigurableTask,
    TaskManager,
    get_task_dict,
)
from lm_eval.utils import eval_logger, positional_deprecated, simple_parse_args_string


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.tasks import Task


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks: Optional[List[Union[str, dict, object]]] = None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str, dict]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_object.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `None` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`. `None` if not desired.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`. `None` if not desired.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.

    :return
        Dictionary of results
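
    Example (an illustrative sketch, not executed here; it assumes the Hugging Face
    backend and the 'hellaswag' task are available in your installation):

        results = simple_evaluate(
            model="hf",
            model_args="pretrained=gpt2",
            tasks=["hellaswag"],
            num_fewshot=0,
            batch_size=8,
        )
        print(results["results"]["hellaswag"])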
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    start_date = time.time()

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    seed_message = []
    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        seed_message.append(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        seed_message.append(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        seed_message.append(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if seed_message:
        eval_logger.info(" | ".join(seed_message))

    if tasks is None:
        tasks = []
    if len(tasks) == 0:
        raise ValueError(
            "No tasks specified, or no tasks found. Please verify the task names."
        )

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. "
            "Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            eval_logger.warning("model_args not specified. Using defaults.")
            model_args = ""
        if "pretrained" not in model_args and model in [
            "hf-auto",
            "hf",
            "huggingface",
            "vllm",
        ]:
            eval_logger.warning(
                "pretrained not specified. Using default pretrained=gpt2."
            )

        if isinstance(model_args, dict):
            eval_logger.info(
                f"Initializing {model} model, with arguments: {model_args}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            eval_logger.info(
                f"Initializing {model} model, with arguments: {simple_parse_args_string(model_args)}"
            )
            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        if not isinstance(model, lm_eval.api.model.LM):
            raise TypeError
        eval_logger.info("Using pre-initialized model")
        lm = model

    if use_cache is not None:
        eval_logger.info(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if check_integrity:
        run_task_tests(task_list=tasks)

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    task_dict = get_task_dict(tasks, task_manager)
    def _adjust_config(task_dict):

        adjusted_task_dict = {}
        for task_name, task_obj in task_dict.items():
            if isinstance(task_obj, dict):
                adjusted_task_dict = {
                    **adjusted_task_dict,
                    **{task_name: _adjust_config(task_obj)}
                }

            else:
                if task_obj.get_config("output_type") == "generate_until":
                    if gen_kwargs is not None:
                        task_obj.set_config(
                            key="generation_kwargs", value=gen_kwargs, update=True
                        )

                if predict_only:
                    log_samples = True
                    eval_logger.info(
                        f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                    )
                    # we have to change the class properties post-hoc. This is pretty hacky.
                    task_obj.override_metric(metric_name="bypass")

                # override tasks' fewshot values to the provided num_fewshot arg value
                # except if tasks have it set to 0 manually in their configs--then we should never overwrite that
                if num_fewshot is not None:
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                        eval_logger.info(
                            f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                        )
                    else:
                        eval_logger.warning(
                            f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                        )
                        task_obj.set_config(key="num_fewshot", value=num_fewshot)
                else:
                    # if num_fewshot not provided, and the task does not define a default one, default to 0
                    if (default_num_fewshot := task_obj.get_config("num_fewshot")) is None:
                        task_obj.set_config(key="num_fewshot", value=0)
                
                adjusted_task_dict[task_name] = task_obj

        return adjusted_task_dict

    task_dict = _adjust_config(task_dict)
    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        write_out=write_out,
        log_samples=log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": (
                list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
            ),
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        results["date"] = start_date
        add_env_info(results)  # additional environment info to results
        return results
    else:
        return None


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    bootstrap_iters: Optional[int] = 100000,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
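
    Example (an illustrative sketch, not executed here; `evaluate` is normally driven
    by `simple_evaluate`, and this assumes `lm` is an already-initialized `LM` instance
    and that the 'hellaswag' task is available):

        task_dict = get_task_dict(["hellaswag"], TaskManager())
        results = evaluate(lm=lm, task_dict=task_dict, limit=10)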
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    # tracks all Instances/requests a model must generate output on.
    requests = defaultdict(list)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = defaultdict(int)

    # get lists of group hierarchy and each type of request
    eval_tasks = get_task_list(task_dict)
    if not log_samples:
        if not all(
            "bypass" not in getattr(task_output.task, "_metric_fn_list", {}).keys()
            for task_output in eval_tasks
        ):
            raise ValueError("log_samples must be True for 'bypass' metric-only tasks")
    for task_output in eval_tasks:
        task: Task = task_output.task
        limit = get_sample_size(task, limit)
        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
        )
        eval_logger.debug(
            f"Task: {task_output.task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            print_writeout(task)
        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )
            # "multiple_choice" task types dispatch (several) "loglikelihood" request types
            reqtype = (
                "loglikelihood"
                if task.OUTPUT_TYPE == "multiple_choice"
                else task.OUTPUT_TYPE
            )
            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            # todo: may not account for padding in cases like SquadV2 which has multiple req types
            padding_requests[reqtype] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    RANK = lm.rank
    WORLD_SIZE = lm.world_size
    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_output in eval_tasks:
        task = task_output.task
        task.apply_filters()

        ### Collect values of metrics on all datapoints ###
        # # unpack results and sort back in order and return control to Task
        # TODO: make it possible to use a different metric per filter
        # Pre-process task.instances to group by doc_id
        instances_by_doc_id = defaultdict(list)
        for instance in task.instances:
            instances_by_doc_id[instance.doc_id].append(instance)
        # Sort instances within each group
        for instances in instances_by_doc_id.values():
            instances.sort(key=lambda x: x.idx)
        # iterate over different filters used
        for filter_key in task.instances[0].filtered_resps.keys():
            doc_iterator = task.doc_iterator(
                rank=RANK, limit=limit, world_size=WORLD_SIZE
            )
            for doc_id, doc in doc_iterator:
                requests = instances_by_doc_id[doc_id]
                metrics = task.process_results(
                    doc, [req.filtered_resps[filter_key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [
                            req.filtered_resps[filter_key] for req in requests
                        ],
                    }
                    example.update(metrics)
                    task_output.logged_samples.append(example)
                for metric, value in metrics.items():
                    task_output.sample_metrics[(metric, filter_key)].append(value)

    if WORLD_SIZE > 1:
        # if multigpu, then gather data across all ranks to rank 0
        # first gather logged samples across all ranks
        for task_output in eval_tasks:
            if log_samples:
                # for task_name, task_samples in list(samples.items()):
                full_samples = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.logged_samples,
                    object_gather_list=full_samples,
                    dst=0,
                )

                if RANK == 0:
                    task_output.logged_samples = list(
                        itertools.chain.from_iterable(full_samples)
                    )

            # then collect metrics across all ranks
            for metrics in task_output.sample_metrics:
                metric_list = [None] * WORLD_SIZE if RANK == 0 else None
                torch.distributed.gather_object(
                    obj=task_output.sample_metrics[metrics],
                    object_gather_list=metric_list,
                    dst=0,
                )
                if RANK == 0:
                    task_output.sample_metrics[metrics] = list(
                        itertools.chain.from_iterable(metric_list)
                    )

    if RANK == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for task_output in eval_tasks:
            task_output.calculate_aggregate_metric(bootstrap_iters=bootstrap_iters)
        results, samples, configs, versions, num_fewshot = consolidate_results(
            eval_tasks
        )

        ### Calculate group metrics ###
        if bool(results):
            def process_group(
                results,
                task_dict,
                task_root=None,
                task_hierarchy=None,
                show_group_table=False,
            ):

                if task_root is None:
                    task_root = {}

                if task_hierarchy is None:
                    task_hierarchy = {}

                for group_or_task, group_or_task_info in task_dict.items():
                    if isinstance(group_or_task_info, ConfigurableTask):
                        if task_root:
                            task_hierarchy.setdefault(task_root, []).append(group_or_task)
                    else:
                        results, _task_hierarchy, show_group_table = process_group(
                            results,
                            group_or_task_info,
                            group_or_task,
                            task_hierarchy,
                            show_group_table,
                        )
                        if task_root:
                            task_hierarchy.setdefault(task_root, []).extend(
                                task_hierarchy.get(group_or_task, [])
                            )

                        if isinstance(group_or_task, ConfigurableGroup):
                            group_config = group_or_task.config
                            group = group_or_task.group
                            show_group_table = show_group_table | group_config["aggregate_metric"]
                            if group_config["aggregate_metric"] is False:
                                results[group][" "] = " "
                                continue
                            
                        elif isinstance(group_or_task, str):
                            results[group_or_task][" "] = " "
                            continue

                        task_list = _task_hierarchy[group_or_task]
                        metric_list = list(
                            {
                                key
                                for task in task_list
                                for key in results[task].keys()
                                if "_stderr" not in key and key not in ["alias", "samples"]
                            }
                        )
                        for metric in metric_list:
                            stderr = "_stderr,".join(metric.split(","))

                            # gather metrics, sizes, and stderrs from subtasks
                            metrics = [
                                results[task][metric]
                                for task in task_list
                                if metric in results[task]
                            ]  # TODO: copy?
                            stderrs = [
                                results[task][stderr]
                                for task in task_list
                                if stderr in results[task]
                            ]
                            sizes = [
                                results[task]["samples"]
                                for task in task_list
                                if metric in results[task]
                            ]

                            # compute group's pooled metric and stderr
                            results[group][
                                metric
                            ] = lm_eval.api.metrics.aggregate_subtask_metrics(
                                metrics,
                                sizes,
                                group_config["weight_by_size"],
                            )
                            # TODO: calculate grouped metric using aggregation fn
                            if "N/A" in stderrs:
                                results[group][stderr] = "N/A"
                            else:
                                results[group][
                                    stderr
                                ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                                # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                                # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                                # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                            results[group]["samples"] = sum(sizes)
                return results, task_hierarchy, show_group_table

            results, task_hierarchy, show_group_table = process_group(results, task_dict)

        results_agg = defaultdict(dict)
        groups_agg = defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v["tasks"] for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = prepare_print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, group_info in task_hierarchy.items():
            task_list = group_info["tasks"]
            if task_list:
                num_fewshot[group_name] = num_fewshot[
                    task_list[0]
                ]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **(
                {"groups": dict(groups_agg.items())}
                if (bool(groups_agg) & show_group_table)
                else {}
            ),
            "group_subtasks": dict(reversed(task_hierarchy.items())),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None


def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in {"true", "refresh"},
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args
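

# Illustrative mapping produced by `request_caching_arg_to_dict`, derived directly
# from the logic above (shown as comments rather than executable module-level code):
#
#   "true"    -> {"cache_requests": True,  "rewrite_requests_cache": False, "delete_requests_cache": False}
#   "refresh" -> {"cache_requests": True,  "rewrite_requests_cache": True,  "delete_requests_cache": False}
#   "delete"  -> {"cache_requests": False, "rewrite_requests_cache": False, "delete_requests_cache": True}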