evaluator.py

import random
import itertools
import json
import collections
import sys

import torch

import numpy as np

import lm_eval.api
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    make_table,
    create_iterator,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :return
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."
    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
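        # e.g. a CLI string like "temperature=0,top_p=0.95" (keys shown are only
        # illustrative) is parsed into a dict of generation kwargs here.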
        eval_logger.warning(
            f"generation_kwargs specified through cli, these settings will be used over set parameters in yaml tasks."
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank" + str(lm.rank) + ".db",
        )
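        # e.g. use_cache="lm_cache/my_model" (illustrative path) resolves to
        # "lm_cache/my_model_rank0.db" on rank 0, per the concatenation above.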

    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if type(task_obj) == tuple:
            group, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if config["output_type"] == "generate_until" and gen_kwargs is not None:
            config["generation_kwargs"].update(gen_kwargs)

        if num_fewshot is not None:
            if config["num_fewshot"] == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                default_num_fewshot = config["num_fewshot"]
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )

                task_obj._config["num_fewshot"] = num_fewshot

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
    )

    if lm.rank == 0:
        # add info about the model and few shot config
        results["config"] = {
            "model": model
            if isinstance(model, str)
            else model.model.config._name_or_path,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
    """

    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated groups scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"

        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)
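            # e.g. limit=0.1 with 1000 docs evaluates 100 docs; limit=50 evaluates 50 docs.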

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info("Running {} requests".format(reqtype))
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if type(task) == tuple:
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
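            # islice(enumerate(docs), lm.rank, limit, lm.world_size) shards documents
            # across ranks: rank r sees docs r, r + world_size, r + 2*world_size, ...
            # stopping before index `limit`.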
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if type(items[0]) == tuple:
                numitem = len(items[0])

            if isinstance(items[0], (str, list)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)
                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:

        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if type(task) == tuple:
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them less iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if task_list == []:
                    total_size = results[group]["samples"]
                else:
                    total_size = 0

                    for task in task_list:
                        metrics = results[task].copy()

                        if "alias" in metrics:
                            metrics.pop("alias")

                        current_size = metrics.pop("samples")
                        # TODO: There should be a way for users
                        #       to toggle between weighted and
                        #       unweighted averaging
                        # For unweighted averaging, use:
                        #     current_size = 1

                        all_stderr = []
                        for metric in [
                            key for key in metrics.keys() if "_stderr" not in key
                        ]:
                            stderr = "_stderr,".join(metric.split(","))
                            stderr_score = results[task][stderr]
                            var_score = stderr_score**2
                            metric_score = results[task][metric]

                            all_stderr.append(stderr)

                            if metric in results[group]:
                                results[group][metric] = (
                                    results[group][metric] * total_size
                                    + metric_score * current_size
                                ) / (total_size + current_size)
                                # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
                                results[group][stderr] = (
                                    (total_size - 1) * results[group][stderr]
                                    + (current_size - 1) * var_score
                                ) / (
                                    total_size + current_size - 1
                                ) + total_size * current_size / (
                                    (total_size + current_size)
                                    * (total_size + current_size - 1)
                                ) * (
                                    results[group][metric] - metric_score
                                ) ** 2
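                                # Mapping to the pooled-variance comment above:
                                # n = total_size, m = current_size, s_x^2 = the group's
                                # running variance (results[group][stderr]), s_y^2 = var_score.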
                            else:
                                results[group][metric] = metric_score
                                results[group][stderr] = var_score

                        total_size += current_size

                    for stderr in all_stderr:
                        results[group][stderr] = np.sqrt(results[group][stderr])

                results[group]["samples"] = total_size

        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")
                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                num_fewshot[group_name] = num_fewshot[task_list[0]]

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict
    else:
        return None