samplers.py 8.61 KB
Newer Older
1
2
from __future__ import annotations

3
4
import logging
import warnings
5
from collections.abc import Iterable, Sequence
Yu Shi Jie's avatar
Yu Shi Jie committed
6
from functools import partial
7
from typing import TYPE_CHECKING, Any
Yu Shi Jie's avatar
Yu Shi Jie committed
8

9
10
11
import datasets


Baber Abbasi's avatar
Baber Abbasi committed
12
13
14
15
16
if TYPE_CHECKING:
    from random import Random

    from lm_eval.api.task import ConfigurableTask, Task

17
18
eval_logger = logging.getLogger("lm-eval")

Baber Abbasi's avatar
Baber Abbasi committed
19

haileyschoelkopf's avatar
haileyschoelkopf committed
20
class ContextSampler:
    """
    Builds few-shot context for a task from a pool of few-shot documents.

    The base implementation draws documents with the supplied `random.Random`
    instance; subclasses change the selection strategy by overriding `sample`.
    """

    def __init__(
        self,
        docs: list[dict],
        task: Task | ConfigurableTask,
        fewshot_indices: Iterable | None = None,
        rnd: Random | None = None,
    ) -> None:
        """
        Args:
            docs: pool of few-shot documents to sample from.
            task: task whose `_config` supplies the delimiters, split names and
                optional `fewshot_config` overrides, and whose `doc_to_text` /
                `doc_to_target` / `doc_to_choice` render each document.
            fewshot_indices: if truthy, `docs` is narrowed with
                `docs.select(fewshot_indices)`; only valid for a HF dataset.
            rnd: random generator used by `sample`. Required.

        Raises:
            ValueError: if `rnd` is missing, or if `fewshot_indices` is given
                while `docs` is not a `datasets.Dataset`.
        """
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # A non-None entry in `fewshot_config` overrides how few-shot docs are
        # rendered; otherwise the task's own formatter is used unchanged.
        self.doc_to_text = self._fewshot_formatter("doc_to_text")
        self.doc_to_target = self._fewshot_formatter("doc_to_target")
        self.doc_to_choice = self._fewshot_formatter("doc_to_choice")

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def _fewshot_formatter(self, field: str):
        """Return the task's `field` formatter, bound to the `fewshot_config` override if one is set."""
        formatter = getattr(self.task, field)
        fewshot_config = self.config.fewshot_config
        override = fewshot_config.get(field, None) if fewshot_config is not None else None
        if override is None:
            return formatter
        return partial(formatter, **{field: override})

    def _select_fewshot_docs(self, doc: dict, num_fewshot: int) -> list[dict]:
        """Draw few-shot docs, dropping `doc` itself if it was drawn, capped at `num_fewshot`."""
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        # draw `n_samples` docs from fewshot_docs
        drawn = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        return [candidate for candidate in drawn if candidate != doc][:num_fewshot]

    def _target_text(self, shot: dict, target: Any) -> str:
        """Render the target of few-shot doc `shot` as a string."""
        if isinstance(target, list):
            # Several references: only the first one is shown.
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            # `str()` keeps string targets unchanged and stops a non-string
            # target (e.g. an int) from breaking string concatenation.
            return str(target)
        # Otherwise the target indexes into the doc's choices.
        return str(self.doc_to_choice(shot)[target])

    def get_context(
        self, doc: dict, num_fewshot: int, gen_prefix: str | None = None
    ) -> str:
        """
        Build the few-shot context for `doc` as a single string.

        Each selected example contributes its text, then (unless its target is
        the empty string) the target delimiter, the optional `gen_prefix`
        followed by a space, the target, and the few-shot delimiter.

        Args:
            doc: the document being evaluated; excluded from the examples.
            num_fewshot: maximum number of examples to include.
            gen_prefix: optional text placed in front of every target.

        Returns:
            The concatenated examples, or "" when none were selected.
        """
        prefix = gen_prefix + " " if gen_prefix else ""

        pieces: list[str] = []
        for shot in self._select_fewshot_docs(doc, num_fewshot):
            content = self.doc_to_text(shot)
            target = self.doc_to_target(shot)

            if isinstance(content, str):
                pieces.append(content)
            elif isinstance(content, int):
                # An integer text is an index into the doc's choices.
                pieces.append(self.doc_to_choice(shot)[content])

            if target != "":
                if self.target_delimiter.isspace() and str(target)[:1].isspace():
                    # TODO: add logger warn once here.
                    warnings.warn(
                        "Both target_delimiter and target start with a space. This may cause issues.",
                        Warning,
                        stacklevel=2,
                    )
                pieces.append(self.target_delimiter)
                pieces.append(prefix)
                pieces.append(self._target_text(shot, target))
                pieces.append(self.fewshot_delimiter)

        return "".join(pieces)

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        gen_prefix: str | None = None,
    ) -> list[dict[str, Any]]:
        """
        Build the few-shot context for `doc` as a list of chat messages.

        Args:
            doc: the document being evaluated; excluded from the examples.
            num_fewshot: maximum number of examples to include.
            fewshot_as_multiturn: if True, every example becomes a "user"
                message (its text) followed by an "assistant" message (its
                target); otherwise the output of `get_context` is wrapped in
                one "user" message.
            gen_prefix: optional text placed in front of every target.

        Returns:
            A list of `{"role": ..., "content": ...}` dicts.
        """
        # TODO: Do we need any other delimiter
        prefix = gen_prefix + " " if gen_prefix else ""

        # NOTE: the draw happens before branching. In the single-turn branch
        # the result is unused and `get_context` draws again; the extra draw is
        # kept so that the sequence of calls to `sample` (and therefore the
        # state of `rnd`) stays the same as before.
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        if not fewshot_as_multiturn:
            # get fewshot context as one user turn
            return [
                {
                    "role": "user",
                    "content": self.get_context(doc, num_fewshot, gen_prefix=gen_prefix),
                }
            ]

        chat_history: list[dict[str, Any]] = []
        for shot in selected_docs:
            content = self.doc_to_text(shot)
            target = self.doc_to_target(shot)
            if self.config.doc_to_choice is None or isinstance(content, str):
                user_content = content
            else:
                user_content = self.doc_to_choice(shot)[content]
            chat_history.append({"role": "user", "content": user_content})
            chat_history.append(
                {"role": "assistant", "content": prefix + self._target_text(shot, target)}
            )
        return chat_history

    # TODO: a commented-out, unfinished `from_fewshot_dfg(cls, cfg: FewshotConfig)`
    # classmethod stub used to live here; implement or drop it.

    def sample(self, n: int) -> Sequence[dict]:
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """
        assert self.rnd is not None, (
            "Error: `rnd` must be set to a random.Random instance before sampling."
        )
        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
204
class FirstNSampler(ContextSampler):
    """Sampler that takes few-shot examples from the front of the pool, in order."""

    def sample(self, n: int) -> Sequence[dict[str, Any]]:
        """
        Return the leading `n` few-shot docs (`self.docs[:n]`) with no shuffling.
        Intended for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
        """
        available = len(self.docs)
        assert n <= available, (
            f"Error: number of fewshot samples requested exceeds the {available} that are available."
        )
        head = self.docs[:n]
        return head


class BalancedSampler(ContextSampler):
    """Placeholder for a class-balanced sampling strategy; `sample` is not implemented."""

    def sample(self, n: int):
        """
        Not implemented: always raises `NotImplementedError`.

        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        raise NotImplementedError()
224
225


haileyschoelkopf's avatar
haileyschoelkopf committed
226
class ManualSampler(ContextSampler):
    """Placeholder sampler; `sample` is not implemented."""

    def sample(self, n: int):
        """Not implemented: always raises `NotImplementedError`."""
        raise NotImplementedError()
230
231


232
# Sampler classes that can be looked up by name through `get_sampler`.
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = dict(
    default=ContextSampler,
    first_n=FirstNSampler,
)


Baber Abbasi's avatar
Baber Abbasi committed
238
def get_sampler(name: str) -> type[ContextSampler]:
    """
    Look up a sampler class by its registered name.

    Args:
        name: a key of `SAMPLER_REGISTRY`.

    Returns:
        The sampler class registered under `name` (the class, not an instance).

    Raises:
        KeyError: if `name` is not registered; the message lists the
            registered sampler names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as e:
        raise KeyError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from e