# evaluator.py

import random
import itertools
import collections

import torch
import numpy as np

import lm_eval.api
import lm_eval.api.metrics
import lm_eval.api.model  # imported explicitly: LM and CachingLM are referenced below
import lm_eval.api.registry
import lm_eval.models
import lm_eval.tasks

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=[],
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
    weight_by_size: bool = False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.models.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Task]
        List of Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a percentage of the total number of examples.
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param decontamination_ngrams_path: str, optional
        Path to decontamination n-grams file (passed through to `evaluate`; not currently applied)
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation.
        Ignored for all tasks with loglikelihood output_type.
    :param weight_by_size: bool
        If True, weight each subtask by its sample count when aggregating group scores; otherwise average subtasks equally
    :return
        Dictionary of results
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through CLI; these settings will override parameters set in YAML tasks."
        )
        if not gen_kwargs:
            # an empty argument string parses to an empty dict; treat it as unset
            # (the previous `gen_kwargs == ""` check could never fire post-parse)
            gen_kwargs = None
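        # Illustrative only: simple_parse_args_string("temperature=0.0,top_p=0.95")
        # is expected to yield a dict along the lines of {"temperature": 0.0, "top_p": 0.95},
        # which is later merged into each generate_until task's generation_kwargs.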

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            # each rank receives a different cache db,
            # necessary to avoid multiple writes to the cache at once
            use_cache + "_rank" + str(lm.rank) + ".db",
        )
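        # e.g. use_cache="lm_cache/model" on rank 0 reads/writes responses
        # at "lm_cache/model_rank0.db" (path shown is illustrative).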

    task_dict = lm_eval.tasks.get_task_dict(tasks)  # resolve task names/objects into a dict
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            group, task_obj = task_obj
            if task_obj is None:
                continue

        config = task_obj._config
        if config["output_type"] == "generate_until" and gen_kwargs is not None:
            config["generation_kwargs"].update(gen_kwargs)

        if num_fewshot is not None:
            if config["num_fewshot"] == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                default_num_fewshot = config["num_fewshot"]
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj._config["num_fewshot"] = num_fewshot
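        # e.g. a CLI num_fewshot=5 overrides a task whose YAML sets num_fewshot: 2
        # (with a warning), while a YAML value of 0 wins and the override is skipped.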

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
        weight_by_size=weight_by_size,
    )

    if lm.rank == 0:
        # add info about the model and few shot config
        results["config"] = {
            "model": model
            if isinstance(model, str)
            else model.model.config._name_or_path,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None
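
# A minimal usage sketch (hypothetical model/task names; assumes an "hf" model
# type and a "hellaswag" task are registered in this installation):
#
#   results = simple_evaluate(
#       model="hf",
#       model_args="pretrained=gpt2",
#       tasks=["hellaswag"],
#       num_fewshot=0,
#       limit=10,  # small subset; smoke-testing only
#   )
#   if results is not None:  # non-zero ranks return None
#       print(results["results"])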

decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    weight_by_size: bool = False,
):
    """Evaluate a model on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task .
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters: int
        Number of iterations for bootstrap statistics
    :param decontamination_ngrams_path: str, optional
        Path to decontamination n-grams file (not currently applied; see the commented-out flag below)
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param weight_by_size: bool
        If True, weight each subtask by its sample count when aggregating group scores; otherwise average subtasks equally
    :return
        Dictionary of results
    """

    # decontaminate = decontamination_ngrams_path is not None

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # tracks each task's version.
    versions = collections.defaultdict(dict)
    # tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # aggregated task scores, presented with groups
    results_agg = collections.defaultdict(dict)
    # aggregated group scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # stores the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # stores num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"
        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
            # note: this rebinds `limit` to an absolute count, so with multiple
            # tasks a fractional limit is resolved against the first task only
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):"
                        f"\n{inst.args[0]}\n(end of prompt on previous line)"
                        f"\ntarget string or answer choice index (starting on next line):"
                        f"\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudo-batches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad
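            # e.g. with world_size=2 and 12 instances on rank 0 but 10 on rank 1,
            # rank 1 records numpad=2 so both ranks run equal numbers of batches.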

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info("Running {} requests".format(reqtype))
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
                # pad with copies of the final request (`req` still points at it)
                # so every rank performs the same number of forward passes
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()
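    # In the dispatch above, `reqtype` is an Instance's request_type, e.g.
    # "loglikelihood" or "generate_until", so `getattr(lm, reqtype)` resolves to
    # the corresponding LM method.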

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue

        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue

        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)
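    # Note: the islice(..., lm.rank, limit, lm.world_size) pattern above shards
    # documents round-robin: with world_size=2, rank 0 scores doc_ids 0, 2, 4, ...
    # and rank 1 scores 1, 3, 5, ..., each stopping before absolute index `limit`.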

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

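    # Numeric sketch of the tensor path above: ranks holding [0.5, 1.0] and
    # [0.25] pad the shorter list with the float32 minimum sentinel, gather
    # everything into one tensor, then drop the sentinel entries before aggregation.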
    if lm.rank == 0:
        ### Aggregate results over all datapoints ###
        # aggregate results; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if isinstance(task, tuple):
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"
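            # e.g. a metric "acc" under a filter named "none" is stored as
            # results[task_name]["acc,none"] with its stderr under "acc_stderr,none".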

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if task_list == []:
                    total_size = results[group]["samples"]
                else:
                    total_size = 0
                    # collected once per group (not per task) so that every
                    # accumulated stderr key gets the final sqrt below
                    all_stderr = []

                    for task in task_list:
                        metrics = results[task].copy()

                        if "alias" in metrics:
                            metrics.pop("alias")

                        # TODO: There should be a way for users
                        #       to toggle between weighted and
                        #       unweighted averaging
                        if weight_by_size:
                            current_size = metrics.pop("samples")
                        else:
                            current_size = 1

                        for metric in [
                            key for key in metrics.keys() if "_stderr" not in key
                        ]:
                            stderr = "_stderr,".join(metric.split(","))
                            stderr_score = results[task][stderr]
                            var_score = stderr_score**2
                            metric_score = results[task][metric]

                            all_stderr.append(stderr)

                            # `results[group][stderr]` accumulates a pooled *variance*;
                            # it is converted back to a stderr by the sqrt pass below.
                            if metric in results[group]:
                                group_mean = results[group][metric]
                                results[group][metric] = (
                                    group_mean * total_size + metric_score * current_size
                                ) / (total_size + current_size)
                                # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
                                # the cross term uses `group_mean`, the mean *before* this update,
                                # matching \bar x in the formula above
                                results[group][stderr] = (
                                    (total_size - 1) * results[group][stderr]
                                    + (current_size - 1) * var_score
                                ) / (
                                    total_size + current_size - 1
                                ) + total_size * current_size / (
                                    (total_size + current_size)
                                    * (total_size + current_size - 1)
                                ) * (group_mean - metric_score) ** 2
                            else:
                                results[group][metric] = metric_score
                                results[group][stderr] = var_score

                        total_size += current_size

                    for stderr in all_stderr:
                        results[group][stderr] = np.sqrt(results[group][stderr])

                results[group]["samples"] = total_size
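        # Worked example of the pooling above (unweighted, current_size=1):
        # two tasks with acc 0.50 and 0.70 pool to a group mean of 0.60, and the
        # between-task variance term is nm(\bar x - \bar y)^2 / ((n+m)(n+m-1))
        # = (0.50 - 0.70)^2 / (2 * 1) = 0.02 before the final sqrt.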

        def print_tasks(task_hierarchy, results, tab=0):
            # despite the name, this does not print: it recursively collects
            # per-task and per-group aggregates, indenting aliases by depth
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                num_fewshot[group_name] = num_fewshot[task_list[0]]

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None
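
# Shape of the dictionary returned on rank 0 (values are illustrative):
#
#   {
#       "results":  {"<task>": {"alias": ..., "acc,none": ..., "acc_stderr,none": ...}},
#       "groups":   {...},  # present only when group aggregation occurred
#       "configs":  {...},  # per-task YAML configs
#       "versions": {...},  # per-task versions
#       "n-shot":   {...},  # per-task few-shot counts
#       "samples":  {...},  # present only when log_samples=True
#   }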