import collections
import itertools
import logging
import math
import random
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry
import lm_eval.models
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    eval_logger,
    positional_deprecated,
    run_task_tests,
    simple_parse_args_string,
)


if TYPE_CHECKING:
    from lm_eval.api.model import LM
    from lm_eval.tasks import Task

from lm_eval.caching.cache import delete_cache


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[Union[str, dict]] = None,
    tasks=None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    cache_requests: bool = False,
    rewrite_requests_cache: bool = False,
    delete_requests_cache: bool = False,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[Union[str, dict]]
        String or dict arguments for each model class, see LM.create_from_arg_string and LM.create_from_arg_obj.
        Ignored if `model` argument is an LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param cache_requests: bool, optional
        Speed up evaluation by caching the building of dataset requests. `False` if not caching.
    :param rewrite_requests_cache: bool, optional
        Rewrites all of the request cache if set to `True`.
    :param delete_requests_cache: bool, optional
        Deletes all of the request cache if set to `True`.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for Python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.

    :return
        Dictionary of results
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    if delete_requests_cache:
        eval_logger.info("Deleting requests cache...")
        delete_cache()

    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        eval_logger.info(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        eval_logger.info(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        eval_logger.info(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""

        elif isinstance(model_args, dict):
            lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )

        else:
            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
                model_args,
                {
                    "batch_size": batch_size,
                    "max_batch_size": max_batch_size,
                    "device": device,
                },
            )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    eval_logger.info(
        "get_task_dict has been updated to accept an optional argument, `task_manager`. "
        "Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
    )
    task_dict = get_task_dict(tasks, task_manager)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.set_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

            if predict_only:
                log_samples = True
                eval_logger.info(
                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                )
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.set_config(key="num_fewshot", value=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        cache_requests=cache_requests,
        rewrite_requests_cache=rewrite_requests_cache,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": (
                list(lm.batch_sizes.values()) if hasattr(lm, "batch_sizes") else []
            ),
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        add_env_info(results)  # additional environment info to results
        return results
    else:
        return None
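

# Minimal usage sketch of `simple_evaluate` (illustrative only; nothing below is
# executed by this module). The model name "hf", the model_args string, and the
# task name "hellaswag" are assumptions about what is installed and registered
# in a given environment.
#
#     results = simple_evaluate(
#         model="hf",
#         model_args="pretrained=EleutherAI/pythia-160m",
#         tasks=["hellaswag"],
#         num_fewshot=0,
#         batch_size=8,
#     )
#     if results is not None:  # non-zero ranks return None
#         print(results["results"])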


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm: "LM",
    task_dict,
    limit: Optional[int] = None,
    cache_requests=False,
    rewrite_requests_cache=False,
    bootstrap_iters: Optional[int] = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    # decontaminate = decontamination_ngrams_path is not None

    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            _, task = task
        if not log_samples:
            assert (
                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        task: Task

        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"

        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        # Number of few-shots for printing.
        if (n_shot := configs[task_name].get("num_fewshot")) == 0:
            n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")

            num_docs = len(task_docs) * limit
            # ceil to prevent limit being equal to 0
            limit = int(math.ceil(num_docs)) if limit < 1.0 else int(limit)

        task.build_all_requests(
            limit=limit,
            rank=lm.rank,
            world_size=lm.world_size,
            cache_requests=cache_requests,
            rewrite_requests_cache=rewrite_requests_cache,
        )

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
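            # each rank iterates over every `world_size`-th document, offset by
            # its own rank, so per-document work is striped evenly across processes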
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            group_name, task = task if isinstance(task, tuple) else (None, task)

            metric_key = f"{metric},{key}"
            agg_fn = task.aggregation()[metric]

            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr_fn = lm_eval.api.metrics.stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=(
                        min(bootstrap_iters, 100)
                        if metric in ["bleu", "chrf", "ter"]
                        else bootstrap_iters
                    ),
                )

                results[task_name][f"{metric}_stderr,{key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if len(task_list) == 0:
                    # task_hierarchy entries are either
                    # `group_name: [subtask1, subtask2, ...]`
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue
                for metric in [
                    key
                    for key in results[task_list[0]].keys()
                    if "_stderr" not in key and key not in ["alias", "samples"]
                ]:  # TODO: what if tasks don't all share the same metrics
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
                    metrics = [
                        results[task][metric] for task in task_list
                    ]  # TODO: copy?
                    stderrs = [results[task][stderr] for task in task_list]
                    sizes = [results[task]["samples"] for task in task_list]

                    # compute group's pooled metric and stderr
                    results[group][
                        metric
                    ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
                        results[group][
                            stderr
                        ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)

        def print_tasks(task_hierarchy, results, tab=0):
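            """Recursively collect display-ready aggregate rows.

            Starting from the first (group, subtask list) entry of `task_hierarchy`,
            copies the matching rows out of `results`, prefixes each alias with
            indentation proportional to `tab`, recurses into subtasks, and returns
            a (results_agg, groups_agg) pair for the subtree.
            """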
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list:
                num_fewshot[group_name] = num_fewshot[
                    task_list[0]
                ]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())},
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
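

# Lower-level usage sketch (illustrative only): build the LM object and task dict
# by hand, then call `evaluate` directly. The "hf" model name, the model_args
# string, and the "lambada_openai" task are assumptions about the local install.
#
#     lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
#         "pretrained=EleutherAI/pythia-160m", {"batch_size": 8}
#     )
#     task_dict = get_task_dict(["lambada_openai"], TaskManager())
#     results = evaluate(lm=lm, task_dict=task_dict, limit=10)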


def request_caching_arg_to_dict(cache_requests: str) -> dict:
    request_caching_args = {
        "cache_requests": cache_requests in ("true", "refresh"),
        "rewrite_requests_cache": cache_requests == "refresh",
        "delete_requests_cache": cache_requests == "delete",
    }

    return request_caching_args
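

# Example (the returned values follow directly from the mapping above):
#   request_caching_arg_to_dict("refresh")
#   -> {"cache_requests": True, "rewrite_requests_cache": True, "delete_requests_cache": False}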