import collections
import itertools
import logging
import random
from typing import Optional, Union

import numpy as np
import torch

import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry
import lm_eval.models
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
    eval_logger,
    get_git_commit_hash,
    positional_deprecated,
    run_task_tests,
    simple_parse_args_string,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args: Optional[str] = None,
    tasks=None,
    num_fewshot: Optional[int] = None,
    batch_size: Optional[int] = None,
    max_batch_size: Optional[int] = None,
    device: Optional[str] = None,
    use_cache: Optional[str] = None,
    limit: Optional[Union[int, float]] = None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: Optional[str] = None,
    task_manager: Optional[TaskManager] = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.

    :return
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs were specified through the CLI; they will update the parameters set in the YAML tasks. Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    eval_logger.info(
        "get_task_dict has been updated to accept an optional argument, `task_manager`. "
        "Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage"
    )
    task_dict = get_task_dict(tasks, task_manager)
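    # NOTE: values in task_dict are either bare Task objects or, for tasks that belong
    # to a group, (group_name, Task) tuples -- hence the tuple unpacking below.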
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.set_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

            if predict_only:
                log_samples = True
                eval_logger.info(
                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                )
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.set_config(key="num_fewshot", value=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None
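# Example usage (illustrative sketch, not part of the original module). `model` may be
# a registry name such as "hf" combined with a `model_args` string, or an
# already-constructed lm_eval.api.model.LM instance. The model and task names below
# are assumptions chosen purely for demonstration:
#
#   results = simple_evaluate(
#       model="hf",
#       model_args="pretrained=EleutherAI/pythia-70m",
#       tasks=["hellaswag"],
#       num_fewshot=0,
#       batch_size=8,
#       limit=0.01,  # floats < 1 are treated as a fraction of each task's documents
#   )
#   if results is not None:  # non-zero ranks return None when running distributed
#       print(results["results"])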


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit: Optional[int] = None,
    bootstrap_iters: Optional[int] = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Evaluate an already-instantiated model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    # decontaminate = decontamination_ngrams_path is not None

    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            _, task = task
        if not log_samples:
            assert (
                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"
        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        # Number of few-shots for printing.
        if (n_shot := configs[task_name].get("num_fewshot")) == 0:
            n_shot = configs[task_name].get("metadata", {}).get("num_fewshot", 0)
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)
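        # build_all_requests populates task.instances with one Instance per
        # (document, request) pair assigned to this rank.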

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)
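        # requests now groups every Instance by its request type,
        # e.g. {"loglikelihood": [Instance, ...], "generate_until": [Instance, ...]}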

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)
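    # vals maps (task_name, filter_key, metric_name) -> list of per-document metric values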

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            group_name, task = task if isinstance(task, tuple) else (None, task)

            metric_key = f"{metric},{key}"
            agg_fn = task.aggregation()[metric]

            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)
lintangsutawika's avatar
lintangsutawika committed
488

489
490
            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them less iterations. still looking for a cleaner way to do this
haileyschoelkopf's avatar
haileyschoelkopf committed
491
            if bootstrap_iters > 0:
Baber Abbasi's avatar
Baber Abbasi committed
492
493
                stderr_fn = lm_eval.api.metrics.stderr_for_metric(
                    metric=agg_fn,
haileyschoelkopf's avatar
haileyschoelkopf committed
494
                    bootstrap_iters=min(bootstrap_iters, 100)
haileyschoelkopf's avatar
haileyschoelkopf committed
495
496
497
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )
498

Baber Abbasi's avatar
Baber Abbasi committed
499
500
501
                results[task_name][f"{metric}_stderr,{key}"] = (
                    stderr_fn(items) if (stderr_fn and len(items) > 1) else "N/A"
                )
Fabrizio Milo's avatar
Fabrizio Milo committed
502

lintangsutawika's avatar
lintangsutawika committed
503
        if bool(results):
504
            for group, task_list in reversed(task_hierarchy.items()):
505
506
507
508
509
510
511
                if len(task_list) == 0:
                    # task_hierarchy entries are either
                    # `group_name: [subtask1, subtask2, ...]`
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue
                for metric in [
Baber Abbasi's avatar
Baber Abbasi committed
512
513
514
515
                    key
                    for key in results[task_list[0]].keys()
                    if "_stderr" not in key and key not in ["alias", "samples"]
                ]:  # TODO: what if tasks don't all share the same metrics
516
517
518
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
Baber Abbasi's avatar
Baber Abbasi committed
519
520
521
                    metrics = [
                        results[task][metric] for task in task_list
                    ]  # TODO: copy?
522
523
524
525
                    stderrs = [results[task][stderr] for task in task_list]
                    sizes = [results[task]["samples"] for task in task_list]
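                    # The group metric below is aggregated from the subtask scores together
                    # with their sample sizes; the group stderr is only pooled when every
                    # subtask reported a numeric stderr (otherwise it stays "N/A").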

                    # compute group's pooled metric and stderr
                    results[group][
                        metric
                    ] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
                        results[group][
                            stderr
                        ] = lm_eval.api.metrics.pooled_sample_stderr(stderrs, sizes)
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)

        def print_tasks(task_hierarchy, results, tab=0):
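            # Recursively walks the group/subtask hierarchy, copying each entry of
            # `results` into display order and prefixing aliases with a tab-level
            # "- " marker so subtasks render nested under their group.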
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list:
                num_fewshot[group_name] = num_fewshot[
                    task_list[0]
                ]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
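

if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustrative addition, not part of the original
    # module). It exercises the lower-level `evaluate` path directly: build an LM via
    # the registry, resolve tasks with a TaskManager, then evaluate. The model name
    # ("hf"), model_args, and task name below are assumptions for demonstration only.
    example_lm = lm_eval.api.registry.get_model("hf").create_from_arg_string(
        "pretrained=EleutherAI/pythia-70m",
        {"batch_size": 1, "max_batch_size": None, "device": None},
    )
    example_task_dict = get_task_dict(["lambada_openai"], TaskManager("INFO"))
    example_results = evaluate(
        lm=example_lm,
        task_dict=example_task_dict,
        limit=5,  # evaluate only a handful of documents; smoke test only
        bootstrap_iters=0,  # skip bootstrap stderr estimation to keep this fast
    )
    if example_results is not None:  # non-zero ranks return None when distributed
        print(example_results["results"])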