import random
import itertools
import collections

import torch
import numpy as np

import lm_eval.api
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.model
import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is interpreted as a fraction of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :return
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )
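        # the resulting cache file is per-rank, e.g. "<use_cache>_rank0.db" on rank 0,
        # so concurrent ranks never write to the same sqlite db.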

    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            group, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if config["output_type"] == "generate_until" and gen_kwargs is not None:
            config["generation_kwargs"].update(gen_kwargs)

        if num_fewshot is not None:
            if config["num_fewshot"] == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                default_num_fewshot = config["num_fewshot"]
                if default_num_fewshot:
                    # warn the user if a specific num_fewshot > 0 was set in the task config;
                    # if unspecified in the config, no warning is emitted.
                    eval_logger.warning(
                        f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                    )

                task_obj._config["num_fewshot"] = num_fewshot

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
):
    """Evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """
    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"

        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
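            # e.g. limit=0.1 with 1,000 task docs evaluates 100 docs; limit=50 evaluates 50.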

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
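            # e.g. if ranks hold [10, 12] instances, rank 0 will later queue 2 extra padding requests.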
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue

        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)

                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:

        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if isinstance(task, tuple):
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if task_list == []:
                    total_size = results[group]["samples"]
                else:
                    total_size = 0

                    for task in task_list:
                        metrics = results[task].copy()

                        if "alias" in metrics:
                            metrics.pop("alias")

                        current_size = metrics.pop("samples")
                        # TODO: There should be a way for users
                        #       to toggle between weighted and
                        #       unweighted averaging
                        # For unweighted averaging, use:
                        #     current_size = 1

                        all_stderr = []
                        for metric in [
                            key for key in metrics.keys() if "_stderr" not in key
                        ]:
                            stderr = "_stderr,".join(metric.split(","))
                            stderr_score = results[task][stderr]
                            if stderr_score == "N/A":
                                var_score = "N/A"
                            else:
                                var_score = stderr_score**2
                                all_stderr.append(stderr)

                            metric_score = results[task][metric]

                            if metric in results[group]:
                                results[group][metric] = (
                                    results[group][metric] * total_size
                                    + metric_score * current_size
                                ) / (total_size + current_size)
                                # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
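                                # i.e. the pooled variance of two samples: a size-weighted
                                # average of the two variances plus a correction for the
                                # difference in means; results[group][stderr] holds a variance
                                # here until the np.sqrt pass over all_stderr below.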
                                if var_score == "N/A" or results[group][stderr] == "N/A":
                                    results[group][stderr] = "N/A"
                                else:
                                    results[group][stderr] = (
                                        (total_size - 1) * results[group][stderr]
                                        + (current_size - 1) * var_score
                                    ) / (
                                        total_size + current_size - 1
                                    ) + total_size * current_size / (
                                        (total_size + current_size)
                                        * (total_size + current_size - 1)
                                    ) * (
                                        results[group][metric] - metric_score
                                    ) ** 2
                            else:
                                results[group][metric] = metric_score
                                results[group][stderr] = var_score

                        total_size += current_size

                    for stderr in all_stderr:
                        results[group][stderr] = np.sqrt(results[group][stderr])

                results[group]["samples"] = total_size

Lintang Sutawika's avatar
Lintang Sutawika committed
544
        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

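        # Illustrative expansion (group/task names are hypothetical): with
        # task_hierarchy = {"mmlu": ["mmlu_anatomy"], "mmlu_anatomy": []},
        # print_tasks emits the "mmlu" group row first, then recurses into
        # "mmlu_anatomy" with tab=1 so its alias renders as " - mmlu_anatomy".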
        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                num_fewshot[group_name] = num_fewshot[task_list[0]]

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None