import random
import itertools
import collections

import torch

import numpy as np

import lm_eval.api
import lm_eval.tasks
import lm_eval.models
import lm_eval.api.metrics
import lm_eval.api.registry

from lm_eval.utils import (
    positional_deprecated,
    run_task_tests,
    get_git_commit_hash,
    simple_parse_args_string,
    eval_logger,
)


@positional_deprecated
def simple_evaluate(
    model,
    model_args=None,
    tasks=None,
    num_fewshot=None,
    batch_size=None,
    max_batch_size=None,
    device=None,
    use_cache=None,
    limit=None,
    bootstrap_iters: int = 100000,
    check_integrity: bool = False,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
    gen_kwargs: str = None,
    predict_only: bool = False,
):
    """Instantiate and evaluate a model on a list of tasks.

    :param model: Union[str, LM]
        Name of model or LM object, see lm_eval.api.registry.get_model
    :param model_args: Optional[str]
        String arguments for each model class, see LM.create_from_arg_string.
        Ignored if `model` argument is a LM object.
    :param tasks: list[Union[str, Task]]
        List of task names or Task objects. Task objects will be taken to have name task.EVAL_HARNESS_NAME if defined and type(task).__name__ otherwise.
    :param num_fewshot: int
        Number of examples in few-shot context
    :param batch_size: int or str, optional
        Batch size for model
    :param max_batch_size: int, optional
        Maximal batch size to try with automatic batch size detection
    :param device: str, optional
        PyTorch device (e.g. "cpu" or "cuda:0") for running models
    :param use_cache: str, optional
        A path to a sqlite db file for caching model responses. `None` if not caching.
    :param limit: int or float, optional
        Limit the number of examples per task (only use this for testing). If <1, limit is a fraction of the total number of examples.
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param check_integrity: bool
        Whether to run the relevant part of the test suite for the tasks
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :param gen_kwargs: str
        String arguments for model generation
        Ignored for all tasks with loglikelihood output_type
    :param predict_only: bool
        If True, only model outputs will be generated and returned; metrics will not be evaluated.

    :return
        Dictionary of results
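
    Example (illustrative only; the model name, model_args string, and task list below
    are assumptions for demonstration, not defaults):

        results = simple_evaluate(
            model="hf",
            model_args="pretrained=EleutherAI/pythia-70m",
            tasks=["lambada_openai"],
            num_fewshot=0,
            batch_size=8,
        )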
    """
    random.seed(0)
    np.random.seed(1234)
    torch.manual_seed(
        1234
    )  # TODO: this may affect training runs that are run with evaluation mid-run.

    if tasks is None:
        tasks = []
    assert (
        tasks != []
    ), "No tasks specified, or no tasks found. Please verify the task names."

    if gen_kwargs is not None:
        gen_kwargs = simple_parse_args_string(gen_kwargs)
        eval_logger.warning(
            "generation_kwargs specified through the CLI; these settings will update the parameters set in YAML task configs. Ensure 'do_sample=True' for non-greedy decoding!"
        )
        if gen_kwargs == "":
            gen_kwargs = None

    if isinstance(model, str):
        if model_args is None:
            model_args = ""
        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
            model_args,
            {
                "batch_size": batch_size,
                "max_batch_size": max_batch_size,
                "device": device,
            },
        )
    else:
        assert isinstance(model, lm_eval.api.model.LM)
        lm = model

    if use_cache is not None:
        print(f"Using cache at {use_cache + '_rank' + str(lm.rank) + '.db'}")
        lm = lm_eval.api.model.CachingLM(
            lm,
            use_cache
            # each rank receives a different cache db.
            # necessary to avoid multiple writes to cache at once
            + "_rank"
            + str(lm.rank)
            + ".db",
        )

    task_dict = lm_eval.tasks.get_task_dict(tasks)
    for task_name in task_dict.keys():
        task_obj = task_dict[task_name]
        if isinstance(task_obj, tuple):
            group, task_obj = task_obj
            if task_obj is None:
                continue

        if task_obj.get_config("output_type") == "generate_until":
            if gen_kwargs is not None:
                task_obj.override_config(
                    key="generation_kwargs", value=gen_kwargs, update=True
                )

            if predict_only:
                log_samples = True
                eval_logger.info(
                    f"Processing {task_name} in output-only mode. Metrics will not be calculated!"
                )
                # we have to change the class properties post-hoc. This is pretty hacky.
                task_obj.override_metric(metric_name="bypass")

        if num_fewshot is not None:
            if (default_num_fewshot := task_obj.get_config("num_fewshot")) == 0:
                eval_logger.info(
                    f"num_fewshot has been set to 0 for {task_name} in its config. Manual configuration will be ignored."
                )
            else:
                eval_logger.warning(
                    f"Overwriting default num_fewshot of {task_name} from {default_num_fewshot} to {num_fewshot}"
                )
                task_obj.override_config(key="num_fewshot", value=num_fewshot)

    if check_integrity:
        run_task_tests(task_list=tasks)

    results = evaluate(
        lm=lm,
        task_dict=task_dict,
        limit=limit,
        bootstrap_iters=bootstrap_iters,
        decontamination_ngrams_path=decontamination_ngrams_path,
        write_out=write_out,
        log_samples=log_samples,
    )

    if lm.rank == 0:
        if isinstance(model, str):
            model_name = model
        elif hasattr(model, "config") and hasattr(model.config, "_name_or_path"):
            model_name = model.config._name_or_path
        else:
            model_name = type(model).__name__

        # add info about the model and few shot config
        results["config"] = {
            "model": model_name,
            "model_args": model_args,
            "batch_size": batch_size,
            "batch_sizes": list(lm.batch_sizes.values())
            if hasattr(lm, "batch_sizes")
            else [],
            "device": device,
            "use_cache": use_cache,
            "limit": limit,
            "bootstrap_iters": bootstrap_iters,
            "gen_kwargs": gen_kwargs,
        }
        results["git_hash"] = get_git_commit_hash()
        return results
    else:
        return None


decontaminate_suffix = "_decontaminate"


@positional_deprecated
def evaluate(
    lm,
    task_dict,
    limit=None,
    bootstrap_iters: int = 100000,
    decontamination_ngrams_path=None,
    write_out: bool = False,
    log_samples: bool = True,
):
    """Evaluate an instantiated model (LM object) on a list of tasks.

    :param lm: obj
        Language Model
    :param task_dict: dict[str, Task]
        Dictionary of tasks. Tasks will be taken to have name type(task).config.task.
    :param limit: int, optional
        Limit the number of examples per task (only use this for testing)
    :param bootstrap_iters:
        Number of iterations for bootstrap statistics
    :param write_out: bool
        If True, write out an example document and model input for checking task integrity
    :param log_samples: bool
        If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis
    :return
        Dictionary of results
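
    Example (illustrative; assumes `lm` is an already-instantiated LM object and
    `task_dict` was built with lm_eval.tasks.get_task_dict):

        results = evaluate(lm=lm, task_dict=task_dict, limit=10)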
    """

    # decontaminate = decontamination_ngrams_path is not None

    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            _, task = task
        if not log_samples:
            assert (
                "bypass" not in getattr(task, "_metric_fn_list", {}).keys()
            ), f"log_samples must be True for 'bypass' only tasks: {task_name}"

    # stores the final result for each task, for each metric/filter pair.
    results = collections.defaultdict(dict)
    # Tracks each task's version.
    versions = collections.defaultdict(dict)
    # Tracks the YAML configs of all chosen tasks.
    configs = collections.defaultdict(dict)
    # logs info about each document evaluated.
    samples = collections.defaultdict(list)
    # tracks all Instances/requests a model must generate output on.
    requests = collections.defaultdict(list)
    # Aggregated task scores presented with groups
    results_agg = collections.defaultdict(dict)
    # Aggregated group scores only
    groups_agg = collections.defaultdict(dict)
    # stores the amount to pad out reqs per req. type so that
    # number of fwd passes per distributed rank is equal
    padding_requests = collections.defaultdict(int)
    # store the hierarchy to do proper ordering
    task_hierarchy = collections.defaultdict(list)
    # store num-fewshot value per task
    num_fewshot = collections.defaultdict(int)

    # get lists of each type of request
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group_name, task = task
            task_hierarchy[group_name].append(task_name)
            versions[group_name] = "N/A"

        else:
            group_name = None
            task_hierarchy[task_name] = []

        if task is None:
            continue

        versions[task_name] = task.VERSION
        configs[task_name] = dict(task.dump_config())

        if "num_fewshot" in configs[task_name]:
            # guard against a missing or empty "metadata" block; without this,
            # `n_shot` could be unbound (or stale from a prior loop iteration) below
            n_shot = (configs[task_name].get("metadata") or {}).get("num_fewshot", None)
            if not n_shot:
                n_shot = configs[task_name]["num_fewshot"]
        else:
            n_shot = 0  # TODO: is this always right?
        num_fewshot[task_name] = n_shot

        if "task_alias" in configs[task_name]:
            results[task_name]["alias"] = configs[task_name]["task_alias"]

        if (
            ("group_alias" in configs[task_name])
            and (group_name not in results)
            and (group_name is not None)
        ):
            results[group_name]["alias"] = configs[task_name]["group_alias"]

        if limit is not None:
            if task.has_test_docs():
                task_docs = task.test_docs()
            elif task.has_validation_docs():
                task_docs = task.validation_docs()
            else:
                raise RuntimeError("Task has neither test_docs nor validation_docs")
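            # a fractional limit (<1.0) is interpreted as a proportion of the task's documents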
            limit = int(len(task_docs) * limit) if limit < 1.0 else int(limit)

        task.build_all_requests(limit=limit, rank=lm.rank, world_size=lm.world_size)

        eval_logger.debug(
            f"Task: {task_name}; number of requests on this rank: {len(task.instances)}"
        )

        if write_out:
            for inst in task.instances:
                # print the prompt for the first few documents
                if inst.doc_id < 1:
                    eval_logger.info(
                        f"Task: {task_name}; document {inst.doc_id}; context prompt (starting on next line):\
\n{inst.args[0]}\n(end of prompt on previous line)\ntarget string or answer choice index (starting on next line):\n{task.doc_to_target(inst.doc)}\n(end of target on previous line)"
                    )
                    eval_logger.info(f"Request: {str(inst)}")

        # aggregate Instances by LM method requested to get output.
        for instance in task.instances:
            reqtype = instance.request_type
            requests[reqtype].append(instance)

        if lm.world_size > 1:
            instances_rnk = torch.tensor(len(task._instances), device=lm.device)
            gathered_item = (
                lm.accelerator.gather(instances_rnk).cpu().detach().numpy().tolist()
            )

            # compute number of pseudobatches to pad with (FSDP/DDP require even batches among ranks)
            numpad = max(gathered_item) - gathered_item[lm.rank]
            padding_requests[task.OUTPUT_TYPE] += numpad

    ### Run LM on inputs, get all outputs ###
    # execute each type of request
    for reqtype, reqs in requests.items():
        eval_logger.info(f"Running {reqtype} requests")
        # create `K` copies of each request `req` based off `K = req.repeats`
        cloned_reqs = []
        for req in reqs:
            cloned_reqs.extend([req] * req.repeats)

        if (lm.world_size > 1) and (padding_requests[reqtype] > 0):
            for _ in range(padding_requests[reqtype]):
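                # pad by repeating the final request so every rank runs the same number of batches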
                cloned_reqs.extend([req] * req.repeats)

        # run requests through model
        resps = getattr(lm, reqtype)(cloned_reqs)

        # put responses from model into a list of length K for each request.
        for x, req in zip(resps, cloned_reqs):
            req.resps.append(x)

        if lm.world_size > 1:
            lm.accelerator.wait_for_everyone()

    ### Postprocess outputs ###
    # TODO: del model here, maybe (idea: allow user to specify device of e.g. reward model separately)
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        task.apply_filters()

    ### Collect values of metrics on all datapoints ###
    vals = collections.defaultdict(list)

    # unpack results and sort back in order and return control to Task
    for task_name, task in task_dict.items():
        if isinstance(task, tuple):
            group, task = task
            if task is None:
                continue
        # TODO: make it possible to use a different metric per filter
        # iterate over different filters used
        for key in task.instances[0].filtered_resps.keys():
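            # each rank iterates over a strided slice of the docs: rank, rank + world_size, ...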
            doc_iterator = (
                itertools.islice(
                    enumerate(task.test_docs()), lm.rank, limit, lm.world_size
                )
                if task.has_test_docs()
                else itertools.islice(
                    enumerate(task.validation_docs()), lm.rank, limit, lm.world_size
                )
            )
            for doc_id, doc in doc_iterator:
                # subset instances to only this document id ; sort by idx
                requests = list(filter(lambda x: x.doc_id == doc_id, task.instances))
                requests.sort(key=lambda x: x.idx)
                metrics = task.process_results(
                    doc, [req.filtered_resps[key] for req in requests]
                )
                if log_samples:
                    target = task.doc_to_target(doc)
                    example = {
                        "doc_id": doc_id,
                        "doc": doc,
                        "target": target,
                        "arguments": [req.args for req in requests],
                        "resps": [req.resps for req in requests],
                        "filtered_resps": [req.filtered_resps[key] for req in requests],
                    }
                    example.update(metrics)
                    samples[task_name].append(example)
                for metric, value in metrics.items():
                    vals[(task_name, key, metric)].append(value)

    if lm.world_size > 1:
        # if multigpu, then gather data across all ranks
        # first gather logged samples across all ranks
        for task_name, task_samples in list(samples.items()):
            full_samples = [None] * lm.world_size
            torch.distributed.all_gather_object(full_samples, task_samples)

            samples[task_name] = list(itertools.chain.from_iterable(full_samples))

        # then collect metrics across all ranks
        vals_torch = collections.defaultdict(list)
        for (task_name, key, metric), items in vals.items():
            numitem = 0
            if isinstance(items[0], tuple):
                numitem = len(items[0])

            if isinstance(items[0], (str, list, tuple)):
                # handle the string case
                gathered_items = [None] * lm.accelerator.num_processes
                torch.distributed.all_gather_object(gathered_items, items)

                gathered_item = list(itertools.chain.from_iterable(gathered_items))
            else:
                # distributed gather requires all ranks to have same dimensions
                # so we pad out with float32 min value
                pad_value = torch.finfo(torch.float32).min
                metrics_tensor = torch.tensor(items, device=lm.device)

                original_dtype = metrics_tensor.dtype  # store original dtype
                torch_device_tensor = lm.accelerator.pad_across_processes(
                    metrics_tensor.to(torch.float32), pad_index=pad_value
                )
                gathered_item = lm.accelerator.gather(torch_device_tensor)

                if numitem > 0:
                    gathered_filtered = gathered_item[gathered_item[:, 0] != pad_value]
                else:
                    gathered_filtered = gathered_item[gathered_item != pad_value]

                gathered_item = (
                    gathered_filtered.to(original_dtype).cpu().detach().numpy().tolist()
                )
                # reconvert if we were passed a tuple of values
                if numitem > 0:
                    gathered_item = [tuple(g) for g in gathered_item]

            if lm.rank == 0:
                vals_torch[(task_name, key, metric)] = gathered_item

        vals = vals_torch

    if lm.rank == 0:

        ### Aggregate results over all datapoints ###
        # aggregate results ; run bootstrap CIs
        for (task_name, key, metric), items in vals.items():
            task = task_dict[task_name]
            metric_key = metric + "," + key

            if isinstance(task, tuple):
                group_name, task = task
            else:
                group_name = None

            agg_fn = task.aggregation()[metric]
            results[task_name][metric_key] = agg_fn(items)
            results[task_name]["samples"] = len(items)

            # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
            # so we run them for fewer iterations. still looking for a cleaner way to do this
            if bootstrap_iters > 0:
                stderr = lm_eval.api.metrics.stderr_for_metric(
                    metric=task.aggregation()[metric],
                    bootstrap_iters=min(bootstrap_iters, 100)
                    if metric in ["bleu", "chrf", "ter"]
                    else bootstrap_iters,
                )

                if stderr is not None and len(items) > 1:
                    results[task_name][metric + "_stderr" + "," + key] = stderr(items)
                else:
                    results[task_name][metric + "_stderr" + "," + key] = "N/A"

        if bool(results):
            for group, task_list in reversed(task_hierarchy.items()):
                if task_list == []:
                    # TODO: No samples when bypass
                    total_size = results[group].get("samples", 999)
                else:
                    total_size = 0

                    for task in task_list:
                        metrics = results[task].copy()

                        if "alias" in metrics:
                            metrics.pop("alias")

                        current_size = metrics.pop("samples")
                        # TODO: There should be a way for users
                        #       to toggle between weighted and
                        #       unweighted averaging
                        # For unweighted averaging, use:
                        #     current_size = 1

                        all_stderr = []
                        for metric in [
                            key for key in metrics.keys() if "_stderr" not in key
                        ]:
                            stderr = "_stderr,".join(metric.split(","))
                            stderr_score = results[task][stderr]
                            if stderr_score == "N/A":
                                var_score = "N/A"
                            else:
                                var_score = stderr_score**2
                                all_stderr.append(stderr)

                            metric_score = results[task][metric]

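                            # maintain a running, sample-size-weighted mean of the metric for the group;
                            # the group variance is pooled using the formula in the comment below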
                            if metric in results[group]:
                                results[group][metric] = (
                                    results[group][metric] * total_size
                                    + metric_score * current_size
                                ) / (total_size + current_size)
                                # $$s_z^2 = \frac{(n-1) s_x^2 + (m-1) s_y^2}{n+m-1} + \frac{nm(\bar x - \bar y)^2}{(n+m)(n+m-1)}.$$
                                if var_score == "N/A" or results[group][stderr] == "N/A":
                                    results[group][stderr] = "N/A"
                                else:
                                    results[group][stderr] = (
                                        (total_size - 1) * results[group][stderr]
                                        + (current_size - 1) * var_score
                                    ) / (
                                        total_size + current_size - 1
                                    ) + total_size * current_size / (
                                        (total_size + current_size)
                                        * (total_size + current_size - 1)
                                    ) * (
                                        results[group][metric] - metric_score
                                    ) ** 2
                            else:
                                results[group][metric] = metric_score
                                results[group][stderr] = var_score

                        total_size += current_size

                    for stderr in all_stderr:
                        results[group][stderr] = np.sqrt(results[group][stderr])

                results[group]["samples"] = total_size

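        # recursively walk the task hierarchy, building ordered result tables with
        # indented aliases (used for the printed summary of tasks and groups)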
        def print_tasks(task_hierarchy, results, tab=0):
            results_agg = collections.defaultdict(dict)
            groups_agg = collections.defaultdict(dict)

            (group_name, task_list), *_ = task_hierarchy.items()
            task_list = sorted(task_list)

            results_agg[group_name] = results[group_name].copy()
            # results_agg[group_name]["tab"] = tab
            if "samples" in results_agg[group_name]:
                results_agg[group_name].pop("samples")

            tab_string = " " * tab + "- " if tab > 0 else ""

            if "alias" in results_agg[group_name]:
                results_agg[group_name]["alias"] = (
                    tab_string + results_agg[group_name]["alias"]
                )
            else:
                results_agg[group_name]["alias"] = tab_string + group_name

            if len(task_list) > 0:
                groups_agg[group_name] = results[group_name].copy()
                # groups_agg[group_name]["tab"] = tab
                if "samples" in groups_agg[group_name]:
                    groups_agg[group_name].pop("samples")

                if "alias" in groups_agg[group_name]:
                    groups_agg[group_name]["alias"] = (
                        tab_string + groups_agg[group_name]["alias"]
                    )
                else:
                    groups_agg[group_name]["alias"] = tab_string + group_name

                for task_name in task_list:
                    if task_name in task_hierarchy:
                        _task_hierarchy = {
                            **{task_name: task_hierarchy[task_name]},
                            **task_hierarchy,
                        }
                    else:
                        _task_hierarchy = {
                            **{task_name: []},
                            **task_hierarchy,
                        }

                    _results_agg, _groups_agg = print_tasks(
                        _task_hierarchy, results, tab + 1
                    )
                    results_agg = {**results_agg, **_results_agg}
                    groups_agg = {**groups_agg, **_groups_agg}

            return results_agg, groups_agg

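        # repeatedly fold the remaining top-level entries of the hierarchy into
        # results_agg until every task and group has been visited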
        results_agg = collections.defaultdict(dict)
        groups_agg = collections.defaultdict(dict)
        all_tasks_list = list(task_hierarchy.keys())
        left_tasks_list = []
        while True:
            add_tasks_list = list(k for k in results_agg.keys())
            left_tasks_list = sorted(list(set(all_tasks_list) - set(add_tasks_list)))
            if len(left_tasks_list) == 0:
                break

            _task_hierarchy = {
                k: v for k, v in task_hierarchy.items() if k in left_tasks_list
            }
            _results_agg, _groups_agg = print_tasks(_task_hierarchy, results)

            results_agg = {**results_agg, **_results_agg}
            groups_agg = {**groups_agg, **_groups_agg}

        for group_name, task_list in task_hierarchy.items():
            if task_list != []:
                num_fewshot[group_name] = num_fewshot[task_list[0]]  # TODO: validate this

        results_dict = {
            "results": dict(results_agg.items()),
            **({"groups": dict(groups_agg.items())} if bool(groups_agg) else {}),
            "configs": dict(sorted(configs.items())),
            "versions": dict(sorted(versions.items())),
            "n-shot": dict(sorted(num_fewshot.items())),
        }
        if log_samples:
            results_dict["samples"] = dict(samples)

        return results_dict

    else:
        return None