import random
import itertools
import collections

import torch
import numpy as np

import lm_eval.api
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Task]
        List of Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is treated as a percentage of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation.
        Ignored for all tasks with loglikelihood output_type.
    :return:
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs were specified through the CLI; these settings will override any generation_kwargs set in YAML task configs."
        )
        if gen_kwargs == "":
            gen_kwargs = None
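
    # For reference, a sketch of what `simple_parse_args_string` is expected to
    # produce (exact coercion rules live in lm_eval.utils):
    #   "temperature=0.7,top_p=0.95,do_sample=True" ->
    #   {"temperature": 0.7, "top_p": 0.95, "do_sample": True}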

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db;
            # necessary to avoid multiple writes to the cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )
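
    # e.g. use_cache="eval_cache" on a two-rank run yields "eval_cache_rank0.db"
    # and "eval_cache_rank1.db": one sqlite cache file per rank.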

    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            _, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if config["output_type"] == "generate_until" and gen_kwargs is not None:
            config["generation_kwargs"].update(gen_kwargs)

        if num_fewshot is not None:
            if config["num_fewshot"] == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                default_num_fewshot = config["num_fewshot"]
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )

                task_obj._config["num_fewshot"] = num_fewshot
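
    # e.g. passing num_fewshot=5 rewrites a task configured with num_fewshot=2
    # into a 5-shot task; tasks explicitly configured as 0-shot are left unchanged.
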
    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
    )

    if lm.rank == 0:
        # add info about the model and few-shot config
        results["config"] = {
            "model": model
            if isinstance(model, str)
            else model.model.config._name_or_path,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None
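

# A minimal usage sketch (illustrative only; assumes the "hf" model backend and
# the "lambada_openai" task are available in this install):
#
#   results = simple_evaluate(
#       model="hf",
#       model_args="pretrained=gpt2",
#       tasks=["lambada_openai"],
#       batch_size=8,
#       limit=10,  # small cap, for smoke-testing only
#   )
#   print(results["results"])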


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    weight_by_size: bool = False,
):
    """Evaluate an instantiated language model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param weight_by_size: bool
        If True, weight each task by its number of samples when aggregating group-level scores
    :return:
        Dictionary of results
    """

    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # tracks each task's version.
    versions = collections.defaultdict(dict)
    # tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # aggregated group scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # stores the task hierarchy, to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # stores the num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"
        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
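            # e.g. limit=0.1 on a 1000-doc task becomes 100 docs. Note that the
            # fractional form is resolved against the first limited task's size,
            # and the resulting integer is then reused for subsequent tasks.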

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info("Running {} requests".format(reqtype))
        # create `K` copies of each request `req`, based on `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            # pad out with copies of the final request so that every rank runs
            # the same number of forward passes
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue

        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results, sort back in order, and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue

        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
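
    # At this point `vals` maps (task_name, filter_key, metric_name) -> list of
    # per-document metric values, e.g. (names illustrative):
    #   ("arc_easy", "none", "acc") -> [1, 0, 1, ...]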

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have the same dimensions,
                # so we pad out with the float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)
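
                # e.g. if rank 0 holds 3 metric values and rank 1 holds 5, rank 0's
                # tensor is padded to length 5 with `pad_value`; gather then returns
                # all 10 entries, and the pads are filtered back out below.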
                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if isinstance(task, tuple):
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap,
            # so we run them with fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if task_list == []:
                    total_size = results[group]["samples"]
                else:
                    total_size = 0

                    for task in task_list:
                        metrics = results[task].copy()

                        if "alias" in metrics:
                            metrics.pop("alias")

                        if weight_by_size:
                            current_size = metrics.pop("samples")
                        else:
                            metrics.pop("samples")
                            current_size = 1

                        all_stderr = []
                        for metric in [
                            key for key in metrics.keys() if "_stderr" not in key
                        ]:
                            stderr = "_stderr,".join(metric.split(","))
                            stderr_score = results[task][stderr]
                            var_score = stderr_score**2
                            metric_score = results[task][metric]

                            all_stderr.append(stderr)

                            if metric in results[group]:
                                # running weighted mean of the two samples
                                results[group][metric] = (
                                    results[group][metric] * total_size
                                    + metric_score * current_size
                                ) / (total_size + current_size)
                                # pooled variance of two samples:
                                # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}$$
                                results[group][stderr] = (
                                    (total_size - 1) * results[group][stderr]
                                    + (current_size - 1) * var_score
                                ) / (
                                    total_size + current_size - 1
                                ) + total_size * current_size / (
                                    (total_size + current_size)
                                    * (total_size + current_size - 1)
                                ) * (results[group][metric] - metric_score) ** 2
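                                # Worked example (illustrative numbers): combining
                                # n=2 docs with mean 0.5, variance 0.25 and m=2 docs
                                # with mean 0.7, variance 0.2 gives mean 0.6 and
                                #   s_z^2 = (0.25 + 0.2)/3 + (2*2*0.2**2)/(4*3) ≈ 0.1633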
                            else:
                                results[group][metric] = metric_score
                                results[group][stderr] = var_score

                        total_size += current_size

                    # variances were accumulated above; take the square root to
                    # return to the stderr scale
                    for stderr in all_stderr:
                        results[group][stderr] = np.sqrt(results[group][stderr])

                results[group]["samples"] = total_size

        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg
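
        # print_tasks walks one root of the hierarchy at a time, indenting each
        # subtask's alias by nesting depth, e.g. (names illustrative):
        #   mmlu
        #    - abstract_algebra
        #    - anatomy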

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                # groups inherit the n-shot value of their first subtask
                num_fewshot[group_name] = num_fewshot[task_list[0]]

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
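

# Shape sketch of the dict `evaluate` returns on rank 0 (keys illustrative;
# metric keys follow the "<metric>,<filter>" convention used above):
#   {
#       "results":  {"<task>": {"acc,none": 0.75, "acc_stderr,none": 0.01, ...}},
#       "groups":   {...},   # present only when group-level aggregation ran
#       "configs":  {...},
#       "versions": {...},
#       "n-shot":   {...},
#       "samples":  {...},   # present only when log_samples=True
#   }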