samplers.py 8.13 KB
Newer Older
1
2
import logging
import warnings
Yu Shi Jie's avatar
Yu Shi Jie committed
3
from functools import partial
Baber Abbasi's avatar
Baber Abbasi committed
4
from typing import TYPE_CHECKING, Iterable, Optional, Union
Yu Shi Jie's avatar
Yu Shi Jie committed
5

6
7
8
import datasets


Baber Abbasi's avatar
Baber Abbasi committed
9
10
11
12
13
if TYPE_CHECKING:
    from random import Random

    from lm_eval.api.task import ConfigurableTask, Task

14
15
# Module-level logger, obtained from the logging hierarchy under the name "lm-eval".
eval_logger = logging.getLogger("lm-eval")

Baber Abbasi's avatar
Baber Abbasi committed
16

haileyschoelkopf's avatar
haileyschoelkopf committed
17
class ContextSampler:
    """
    Draw few-shot example docs and render them into a prompt context.

    The sampler keeps a pool of candidate docs (`self.docs`), draws from it with
    `sample`, and formats the drawn docs either as one string (`get_context`) or
    as a list of chat messages (`get_chat_context`). Subclasses change the
    selection strategy by overriding `sample`.
    """

    def __init__(
        self,
        docs: list[dict],
        task: Union["Task", "ConfigurableTask"],
        fewshot_indices: Optional[Iterable] = None,
        rnd: Optional["Random"] = None,
    ) -> None:
        """
        Args:
            docs: pool of few-shot docs to draw from.
            task: task object; its `_config`, `doc_to_text`, `doc_to_target` and
                `doc_to_choice` attributes are read here.
            fewshot_indices: if given (and non-empty), `docs` is narrowed with
                `docs.select(fewshot_indices)`; `docs` must then be a
                `datasets.Dataset`.
            rnd: random generator used by `sample`.

        Raises:
            ValueError: if `rnd` is missing, or if `fewshot_indices` is given
                while `docs` is not a `datasets.Dataset`.
        """
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # Each formatter is the task's own method unless `fewshot_config`
        # supplies an override under the same name.
        self.doc_to_text = self._resolve_processor("doc_to_text")
        self.doc_to_target = self._resolve_processor("doc_to_target")
        self.doc_to_choice = self._resolve_processor("doc_to_choice")

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def _resolve_processor(self, name: str):
        """Return `self.task.<name>`, bound to the `fewshot_config` override for `name` if one is set."""
        task_method = getattr(self.task, name)
        fewshot_config = self.config.fewshot_config
        override = None if fewshot_config is None else fewshot_config.get(name, None)
        if override is None:
            return task_method
        # The override is forwarded as a keyword argument named after the method.
        return partial(task_method, **{name: override})

    def _select_fewshot_docs(self, doc: dict, num_fewshot: int) -> list:
        """Draw few-shot docs via `sample` and drop any that compare equal to `doc`."""
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        return [x for x in fewshotex if x != doc][:num_fewshot]

    def _target_text(self, shot: dict, target) -> str:
        """Render the target of one few-shot doc as text."""
        if isinstance(target, list):
            # Only the first element of a list target is used.
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            # `str()` leaves string targets untouched and keeps non-string
            # targets (e.g. ints) from breaking string concatenation.
            return str(target)
        # Otherwise the target is used as an index/key into the doc's choices.
        return str(self.doc_to_choice(shot)[target])

    def get_context(
        self, doc: dict, num_fewshot: int, gen_prefix: Optional[str] = None
    ) -> str:
        """
        Render the few-shot examples for `doc` as a single string.

        Only the few-shot examples are rendered; `doc` itself is used solely to
        exclude it from the drawn examples.

        Args:
            doc: the doc being evaluated.
            num_fewshot: maximum number of few-shot examples to include.
            gen_prefix: optional text placed (followed by a space) between the
                target delimiter and each example's target.

        Returns:
            For each selected example: its text, then - only when its target is
            not the empty string - the target delimiter, the prefix, the target
            and the few-shot delimiter, all concatenated.
        """
        prefix = f"{gen_prefix} " if gen_prefix else ""
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        pieces = []
        for shot in selected_docs:
            shot_content = self.doc_to_text(shot)
            shot_target = self.doc_to_target(shot)

            if self.config.doc_to_choice is None or isinstance(shot_content, str):
                pieces.append(shot_content)
            else:
                # Non-string content is used as an index/key into the doc's choices.
                pieces.append(self.doc_to_choice(shot)[shot_content])

            if shot_target != "":
                target_text = self._target_text(shot, shot_target)
                # Check the text that is actually appended, so list and
                # choice-index targets are covered as well.
                if self.target_delimiter.isspace() and target_text[:1].isspace():
                    # TODO: add logger warn once here.
                    warnings.warn(
                        "Both target_delimiter and target start with a space. This may cause issues.",
                        Warning,
                        stacklevel=2,
                    )
                pieces.extend(
                    (self.target_delimiter, prefix, target_text, self.fewshot_delimiter)
                )

        return "".join(pieces)

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        gen_prefix: Optional[str] = None,
    ) -> list:
        """
        Render the few-shot examples for `doc` as a list of chat messages.

        Args:
            doc: the doc being evaluated.
            num_fewshot: maximum number of few-shot examples to include.
            fewshot_as_multiturn: if true, each example becomes a "user"
                message (its text) followed by an "assistant" message (prefix +
                target); otherwise the whole `get_context` string is returned
                as a single "user" message.
            gen_prefix: optional text placed (followed by a space) before each
                example's target.

        Returns:
            A list of `{"role": ..., "content": ...}` dicts.
        """
        # TODO: Do we need any other delimiter
        prefix = f"{gen_prefix} " if gen_prefix else ""

        # Drawn before branching: the single-turn path below does not use this
        # selection (`get_context` draws again), but skipping the draw would
        # change how often `sample` is invoked per call.
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        if not fewshot_as_multiturn:
            # get fewshot context as one user turn
            return [
                {
                    "role": "user",
                    "content": self.get_context(
                        doc, num_fewshot, gen_prefix=gen_prefix
                    ),
                }
            ]

        chat_history = []
        for shot in selected_docs:
            shot_content = self.doc_to_text(shot)
            shot_target = self.doc_to_target(shot)
            if self.config.doc_to_choice is None or isinstance(shot_content, str):
                user_content = shot_content
            else:
                user_content = self.doc_to_choice(shot)[shot_content]
            chat_history.append({"role": "user", "content": user_content})
            chat_history.append(
                {
                    "role": "assistant",
                    "content": prefix + self._target_text(shot, shot_target),
                }
            )

        return chat_history

    def sample(self, n: int):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
192
class FirstNSampler(ContextSampler):
    """Sampler that returns the leading few-shot docs in their stored order."""

    def sample(self, n: int):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Raises:
            ValueError: if `n` is larger than the number of available docs.
        """
        available = len(self.docs)
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if n > available:
            raise ValueError(
                f"Error: number of fewshot samples requested exceeds the {available} that are available."
            )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    """Placeholder for a class-balanced few-shot sampler; `sample` is not implemented."""

    def sample(self, n: int):
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?

        Raises:
            NotImplementedError: always. Failing here is clearer than returning
                `None`, which the context builders would then try to iterate.
        """
        raise NotImplementedError("BalancedSampler.sample is not implemented yet.")


haileyschoelkopf's avatar
haileyschoelkopf committed
214
class ManualSampler(ContextSampler):
    """Placeholder sampler; `sample` is not implemented."""

    def sample(self, n: int):
        """
        Not implemented.

        Raises:
            NotImplementedError: always. Failing here is clearer than returning
                `None`, which the context builders would then try to iterate.
        """
        raise NotImplementedError("ManualSampler.sample is not implemented yet.")
218
219


haileyschoelkopf's avatar
haileyschoelkopf committed
220
221
222
223
224
225
# Maps the sampler names accepted by `get_sampler` to their sampler classes.
# Only these two classes are registered; BalancedSampler and ManualSampler are not.
SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


Baber Abbasi's avatar
Baber Abbasi committed
226
def get_sampler(name: str):
    """
    Look up a sampler class by its registered name.

    Args:
        name: key into `SAMPLER_REGISTRY`.

    Returns:
        The class registered under `name` (the class itself, not an instance).

    Raises:
        ValueError: if `name` is not a registered sampler name.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as err:
        # Chain explicitly so the traceback shows the lookup failure as the cause.
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from err