import collections
import itertools
import logging
import random
from typing import Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry
import lm_eval.models
from lm_eval.logging_utils import add_env_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    eval_logger,
    positional_deprecated,
    run_task_tests,
    simple_parse_args_string,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[str] = None,
    tasks=None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
    random_seed: int = 0,
    numpy_random_seed: int = 1234,
    torch_random_seed: int = 1234,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.
    :param random_seed: int
        Random seed for python's random module. If set to None, the seed will not be set.
    :param numpy_random_seed: int
        Random seed for numpy. If set to None, the seed will not be set.
    :param torch_random_seed: int
        Random seed for torch. If set to None, the seed will not be set.

    :return
        Dictionary of results
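
    Example (illustrative only; the model, model_args, and task names below are
    placeholders, substitute whatever is available in your installation):

        results = simple_evaluate(
            model="hf",
            model_args="pretrained=EleutherAI/pythia-160m",
            tasks=["lambada_openai"],
            num_fewshot=0,
            batch_size=8,
        )
        # per-task metrics are found under results["results"], keyed as "<metric>,<filter>"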
    """
    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    if random_seed is not None:
        # See https://github.com/EleutherAI/lm-evaluation-harness/pull/1412
        eval_logger.info(f"Setting random seed to {random_seed}")
        random.seed(random_seed)

    if numpy_random_seed is not None:
        eval_logger.info(f"Setting numpy seed to {numpy_random_seed}")
        np.random.seed(numpy_random_seed)

    if torch_random_seed is not None:
        eval_logger.info(f"Setting torch manual seed to {torch_random_seed}")
        torch.manual_seed(torch_random_seed)

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
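        # e.g. "temperature=0,top_k=0,top_p=1" is parsed into {"temperature": 0, "top_k": 0, "top_p": 1}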
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will update set parameters in yaml tasks. Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
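        # `model` is a registry name: resolve the model class via the registry and
        # instantiate it from the CLI-style arg string plus the batch/device settings below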
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    eval_logger.info(
        "get_task_dict has been updated to accept an optional argument, `task_manager`"
        "Read more here:https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
    )
    task_dict = get_task_dict(tasks, task_manager)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.set_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

            if predict_only:
                log_samples = True
                eval_logger.info(
                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                )
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.set_config(key="num_fewshot", value=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        add_env_info(results)  # additional environment info to results
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit: Optional[int] = None,
    bootstrap_iters: Optional[int] = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
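
    Example (illustrative sketch; assumes `lm` is an already-instantiated LM object
    and that the named task is available in your installation):

        task_dict = get_task_dict(["lambada_openai"], TaskManager())
        results = evaluate(lm=lm, task_dict=task_dict, limit=10)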
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    # decontaminate = decontamination_ngrams_path is not None

    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            _, task = task
        if not log_samples:
            assert (
                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"

        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        # Number of few-shots for printing.
        if (n_shot := configs[task_name].get("num_fewshot")) == 0:
            n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
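            # pad with copies of the final request so that every rank runs the same
            # number of forward passes (DDP/FSDP require matching batch counts across ranks)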
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
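            # shard the docs across ranks: each rank takes every `world_size`-th
            # document starting from its own rank index, up to `limit` if set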
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            group_name, task = task if isinstance(task, tuple) else (None, task)

            metric_key = f"{metric},{key}"
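            # results keys pair the metric with the filter that produced it, e.g. "acc,none"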
            agg_fn = task.aggregation()[metric]

            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr_fn = lm_eval.api.metrics.stderr_for_metric(
                    metric=agg_fn,
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                results[task_name][f"{metric}_stderr,{key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if len(task_list) == 0:
                    # task_hierarchy entries are either
                    # `group_name: [subtask1, subtask2, ...]`
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue
                for metric in [
                    key
                    for key in results[task_list[0]].keys()
                    if "_stderr" not in key and key not in ["alias", "samples"]
                ]:  # TODO: what if tasks don't all share the same metrics
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
                    metrics = [
                        results[task][metric] for task in task_list
                    ]  # TODO: copy?
                    stderrs = [results[task][stderr] for task in task_list]
                    sizes = [results[task]["samples"] for task in task_list]

                    # compute group's pooled metric and stderr
                    results[group][
                        metric
                    ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
                        results[group][
                            stderr
                        ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)

        def print_tasks(task_hierarchy, results, tab=0):
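            # recursively walk the task hierarchy: copy each group's/task's aggregated
            # results and prefix its alias with indentation so subtasks print nested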
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list:
                num_fewshot[group_name] = num_fewshot[
                    task_list[0]
                ]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "group_subtasks": {k: v for k, v in reversed(task_hierarchy.items())},
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None