import os
import logging
import evaluate
from functools import partial

from lm_eval.api.model import LM

eval_logger = logging.getLogger("lm-eval")

# model alias (string) -> LM subclass
MODEL_REGISTRY = {}

class HFEvaluateAdaptor:
    """Adaptor exposing a Hugging Face `evaluate` metric through lm-eval's aggregation interface."""

    def __init__(self, name, **kwargs):
        self.name = name
        metric_object = evaluate.load(name)
        # Bind any extra keyword arguments so the metric can later be called with refs/preds only.
        self.hf_evaluate_fn = partial(metric_object.compute, **kwargs)

    def __call__(self, items):
        # `items` is an iterable of (reference, prediction) pairs.
        refs, preds = zip(*items)

        return self.hf_evaluate_fn(references=refs, predictions=preds)[self.name]
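
# Usage sketch (illustrative only, not part of this module; assumes the HF
# "exact_match" metric can be loaded via `evaluate.load`):
#
#     em = HFEvaluateAdaptor("exact_match")
#     items = [("paris", "paris"), ("london", "berlin")]  # (reference, prediction) pairs
#     em(items)  # -> 0.5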

def register_model(*names):
    # Accepts one or more alias strings as positional arguments;
    # the decorator receives them as a tuple of strings.

    def decorate(cls):
        for name in names:
            assert issubclass(
                cls, LM
            ), f"Model '{name}' ({cls.__name__}) must extend LM class"

            assert (
                name not in MODEL_REGISTRY
            ), f"Model named '{name}' conflicts with existing model! Please register with a non-conflicting alias instead."

            MODEL_REGISTRY[name] = cls
        return cls

    return decorate


def get_model(model_name):
    try:
        return MODEL_REGISTRY[model_name]
    except KeyError:
        raise ValueError(
            f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
        )
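
# Usage sketch (illustrative only; `MyLM` is a hypothetical LM subclass, not part
# of this module):
#
#     @register_model("my-lm", "my-lm-alias")
#     class MyLM(LM):
#         ...
#
#     get_model("my-lm")    # -> MyLM
#     get_model("unknown")  # -> raises ValueError listing the supported names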


# task name -> the decorated task class or function
TASK_REGISTRY = {}
# group name -> list of task names belonging to the group
GROUP_REGISTRY = {}
# union of all registered task and group names
ALL_TASKS = set()
# __name__ of the decorated object -> the task name it was registered under
func2task_index = {}


def register_task(name):
    def decorate(fn):
        assert (
            name not in TASK_REGISTRY
        ), f"task named '{name}' conflicts with existing registered task!"

        TASK_REGISTRY[name] = fn
        ALL_TASKS.add(name)
        func2task_index[fn.__name__] = name
        return fn

    return decorate


def register_group(name):
    def decorate(fn):
        func_name = func2task_index[fn.__name__]
        if name in GROUP_REGISTRY:
            GROUP_REGISTRY[name].append(func_name)
        else:
            GROUP_REGISTRY[name] = [func_name]
            ALL_TASKS.add(name)
        return fn

    return decorate
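
# Usage sketch (illustrative only; `MyTask` is a hypothetical task class, not part
# of this module). Because `register_group` resolves the task name through
# `func2task_index`, place it *above* `register_task` so the task is registered
# first when the decorators are applied bottom-up:
#
#     @register_group("my_benchmark")
#     @register_task("my_task")
#     class MyTask:
#         ...
#
#     TASK_REGISTRY["my_task"]        # -> MyTask
#     GROUP_REGISTRY["my_benchmark"]  # -> ["my_task"]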


# metric name -> metric function
METRIC_FUNCTION_REGISTRY = {}
# metric name -> bool, True if larger values of the metric indicate better performance
HIGHER_IS_BETTER_REGISTRY = {}

# output type -> metric names registered as defaults for that output type
DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [],
    "loglikelihood_rolling": [],
    "multiple_choice": [],
    "generate_until": [],
}


def register_metric(
    metric=None,
    higher_is_better=None,
    output_type=None,
):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        # `metric` may be a single metric name or a list of names sharing the same function.
        if isinstance(metric, str):
            metric_list = [metric]
        elif isinstance(metric, list):
            metric_list = metric
        else:
            raise ValueError(
                f"register_metric expects `metric` to be a str or list of str, got {type(metric)}"
            )

        for _metric in metric_list:
            METRIC_FUNCTION_REGISTRY[_metric] = fn

            if higher_is_better is not None:
                HIGHER_IS_BETTER_REGISTRY[_metric] = higher_is_better

            if output_type is not None:
                # `output_type` may likewise be a single name or a list of names.
                output_type_list = [output_type] if isinstance(output_type, str) else output_type

                for _output_type in output_type_list:
                    DEFAULT_METRIC_REGISTRY[_output_type].append(_metric)

        return fn

    return decorate
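
# Usage sketch (illustrative only; `my_exact_match` is a hypothetical metric
# function, not part of this module, and its signature is not enforced here --
# see the TODO above):
#
#     @register_metric(
#         metric="my_exact_match",
#         higher_is_better=True,
#         output_type=["generate_until", "multiple_choice"],
#     )
#     def my_exact_match(references, predictions):
#         ...
#
#     is_higher_better("my_exact_match")  # -> True (see below)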


def get_metric(name, hf_evaluate_metric=False, **kwargs):
    # Prefer natively registered metrics unless the caller explicitly asks for HF Evaluate.
    if not hf_evaluate_metric:
        if name in METRIC_FUNCTION_REGISTRY:
            return METRIC_FUNCTION_REGISTRY[name]
        else:
            eval_logger.warning(
                f"Could not find registered metric '{name}' in lm-eval, searching in HF Evaluate library..."
            )

    try:
        # Fall back to wrapping the metric from the HF Evaluate library.
        return HFEvaluateAdaptor(name, **kwargs)
    except Exception:
        eval_logger.error(
            f"{name} not found in the evaluate library! Please check https://huggingface.co/evaluate-metric"
        )
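
# Usage sketch (illustrative only; assumes an "acc" metric was registered elsewhere
# via @register_metric, and that the HF "bleu" metric can be loaded locally):
#
#     acc_fn = get_metric("acc")                              # natively registered function
#     bleu_fn = get_metric("bleu", hf_evaluate_metric=True)   # wraps HF Evaluate's "bleu"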


def is_higher_better(metric_name):
    try:
        return HIGHER_IS_BETTER_REGISTRY[metric_name]
    except KeyError:
        eval_logger.warning(
            f"higher_is_better not specified for metric '{metric_name}'!"
        )