import random
import itertools
import collections

import torch

import logging
import numpy as np

import lm_eval.api
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry

from lm_eval.tasks import (
    get_task_dict,
    TaskManager
)
from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
    task_manager: TaskManager = None,
    verbosity: str = "INFO",
    predict_only: bool = False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, dict, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation.
        Ignored for all tasks with loglikelihood output_type.
    :param task_manager: TaskManager, optional
        Handles task lookup for `get_task_dict`; a new one is created if not provided.
    :param verbosity: str
        Logging verbosity level (e.g. "INFO", "DEBUG") for eval_logger.
    :param predict_only: bool
        If True, only model outputs will be generated and returned. Metrics will not be evaluated.

    :return
        Dictionary of results
    """
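    # A minimal usage sketch (illustrative only; the model and task names below are
    # assumptions and depend on what is registered in this install):
    #
    #   results = simple_evaluate(
    #       model="hf",
    #       model_args="pretrained=EleutherAI/pythia-160m",
    #       tasks=["lambada_openai"],
    #       num_fewshot=0,
    #       batch_size=8,
    #   )
    #   print(results["results"])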
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs were specified through the CLI; these settings will update the generation parameters set in YAML tasks. Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if not gen_kwargs:
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    if task_manager is None:
        task_manager = TaskManager(verbosity)

    eval_logger.info(
        "get_task_dict has been updated to accept an optional argument, `task_manager`. "
        "Read more here: https://github.com/EleutherAI/lm-evaluation-harness/blob/recursive-groups/docs/interface.md#external-library-usage"
    )
    task_dict = get_task_dict(tasks, task_manager)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.override_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

            if predict_only:
                log_samples = True
                eval_logger.info(
                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                )
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.override_config(key="num_fewshot", value=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
        verbosity=verbosity,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    verbosity: str = "INFO",
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """

    eval_logger.setLevel(getattr(logging, f"{verbosity}"))
    # decontaminate = decontamination_ngrams_path is not None
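    # High-level flow: (1) build all requests for every task, (2) run the LM once per
    # request type, (3) apply filters and score each document, then (4) aggregate
    # metrics (with bootstrapped stderrs) and roll subtask scores up into their groups.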

    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            _, task = task
        if not log_samples:
            assert (
                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"
        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            n_shot = None
            if configs[task_name]["metadata"]:
                n_shot = configs[task_name]["metadata"].get("num_fewshot", None)
            if not n_shot:
                n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0  # TODO: is this always right?
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
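
    # `vals` maps (task_name, filter_key, metric_name) -> list of per-example values;
    # these lists are aggregated (and bootstrapped for stderrs) below.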

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:

        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if isinstance(task, tuple):
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them with fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if len(task_list) == 0:
                    # task_hierarchy entries are either
                    # `group_name: [subtask1, subtask2, ...]`
                    # or `task_name: []`.
                    # we only want to operate on groups here.
                    continue
                for metric in [
                    key
                    for key in results[task_list[0]].keys()
                    if "_stderr" not in key and key not in ["alias", "samples"]
                ]:  # TODO: what if tasks don't all share the same metrics
                    stderr = "_stderr,".join(metric.split(","))

                    # gather metrics, sizes, and stderrs from subtasks
                    metrics = [results[task][metric] for task in task_list]  # TODO: copy?
                    stderrs = [results[task][stderr] for task in task_list]
                    sizes = [results[task]["samples"] for task in task_list]

                    # compute group's pooled metric and stderr
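                    # (assumed to pool as a size-weighted mean of subtask scores: e.g.
                    # accuracies of 0.50 and 0.70 over 100 and 300 docs give
                    # (0.50*100 + 0.70*300) / 400 = 0.65)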
                    results[group][metric] = lm_eval.api.metrics.aggregate_subtask_metrics(metrics, sizes)
                    # TODO: calculate grouped metric using aggregation fn
                    if "N/A" in stderrs:
                        results[group][stderr] = "N/A"
                    else:
                        results[group][stderr] = lm_eval.api.metrics.pooled_sample_stderr(
                            stderrs, sizes
                        )
                        # TODO: allow GroupConfigs to choose which variance formula is used, for back-compatibility
                        # To use the old (likely incorrect) variance formula, comment out the above and uncomment this line:
                        # results[group][stderr] = lm_eval.api.metrics.combined_sample_stderr(stderrs, sizes, metrics=metrics)

                    results[group]["samples"] = sum(sizes)

        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg
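
        # Repeatedly walk the task hierarchy with print_tasks until every task and
        # group has been folded into results_agg / groups_agg in display order.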

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                num_fewshot[group_name] = num_fewshot[task_list[0]]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None