import logging

from .aggregation import *
from .metric import *

from lm_eval.api.metrics import bootstrap_stderr, mean_stderr, acc_all_stderr
from lm_eval.api.register import (
    metric_registry,
    aggregation_registry,
    higher_is_better_registry,
    output_type_registry,
    default_aggregation_registry,
)

logger = logging.getLogger(__name__)

METRIC_REGISTRY = metric_registry
OUTPUT_TYPE_REGISTRY = output_type_registry
AGGREGATION_REGISTRY = aggregation_registry
DEFAULT_AGGREGATION_REGISTRY = default_aggregation_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": [
        "acc",
    ],
    "greedy_until": ["exact_match"],
}
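
# Illustrative sketch (not part of the original file): tasks that do not declare
# their own metrics can fall back to the defaults for their output type. The
# keys and values below are taken directly from DEFAULT_METRIC_REGISTRY above.
#
#   DEFAULT_METRIC_REGISTRY["multiple_choice"]        # -> ["acc"]
#   DEFAULT_METRIC_REGISTRY["loglikelihood_rolling"]  # -> ["word_perplexity", "byte_perplexity", "bits_per_byte"]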


def get_metric(name):
    """Return the metric callable registered under `name`.

    Falls back to the HF Evaluate library if the metric is not found in the
    lm-eval registry.
    """
    try:
        return METRIC_REGISTRY[name]
    except KeyError:
        logger.warning(
            f"Could not find registered metric '{name}' in lm-eval, "
            "searching in HF Evaluate library..."
        )
        try:
            import evaluate

            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                f"{name} not found in the evaluate library! "
                "Please check https://huggingface.co/evaluate-metric"
            )
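
# Usage sketch (illustrative, not part of the original file): resolving metrics
# by name. "acc" is assumed to be registered in METRIC_REGISTRY, consistent with
# the defaults above; "rouge" is a hypothetical example of a metric that is not
# registered locally and is instead resolved through evaluate.load().
#
#   acc_fn = get_metric("acc")      # looked up in METRIC_REGISTRY
#   rouge_fn = get_metric("rouge")  # falls back to evaluate.load("rouge").compute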


def get_aggregation(name):
    """Return the aggregation function registered under `name`."""
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        raise Warning(
            f"{name} is not a registered aggregation metric!",
        )
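
# Usage sketch (illustrative): "mean" is assumed to be registered in
# AGGREGATION_REGISTRY, consistent with its use in stderr_for_metric below.
#
#   agg = get_aggregation("mean")
#   agg([0.0, 1.0, 1.0])   # -> 0.666...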


def stderr_for_metric(metric, bootstrap_iters):
    """Return a standard-error estimator for `metric`, or None if no estimator
    is available.

    Metrics listed in `bootstrappable` get a bootstrap estimator that resamples
    the per-item values `bootstrap_iters` times; "mean" and "acc_all" use
    closed-form estimators.
    """
    # Metrics whose stderr must be estimated by bootstrapping over the items.
    bootstrappable = [
        "median",
        "matthews_corrcoef",
        "f1_score",
        "perplexity",
        "bleu",
        "chrf",
        "ter",
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(
            METRIC_REGISTRY[metric], x, iters=bootstrap_iters
        )

    # Closed-form standard errors for simple aggregations.
    stderr = {"mean": mean_stderr, "acc_all": acc_all_stderr}

    return stderr.get(metric, None)
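
# Usage sketch (illustrative, not part of the original file): choosing a
# standard-error estimator. For "mean" this returns the closed-form mean_stderr;
# for a bootstrappable metric such as "f1_score" it returns a closure that
# resamples the per-item values bootstrap_iters times. The data shown is
# hypothetical.
#
#   stderr_fn = stderr_for_metric("mean", bootstrap_iters=1000)
#   stderr_fn([0, 1, 1, 0, 1])     # standard error of the sample mean
#
#   boot_fn = stderr_for_metric("f1_score", bootstrap_iters=1000)
#   boot_fn(items)                 # items: per-document (gold, pred) pairs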