samplers.py 7.75 KB
Newer Older
Yu Shi Jie's avatar
Yu Shi Jie committed
1
from functools import partial
Baber Abbasi's avatar
Baber Abbasi committed
2
from typing import TYPE_CHECKING, Iterable, Optional, Union
Yu Shi Jie's avatar
Yu Shi Jie committed
3

4
5
6
import datasets


Baber Abbasi's avatar
Baber Abbasi committed
7
8
9
10
11
12
if TYPE_CHECKING:
    from random import Random

    from lm_eval.api.task import ConfigurableTask, Task


haileyschoelkopf's avatar
haileyschoelkopf committed
13
class ContextSampler:
    """Draw few-shot examples for a task and render them as prompt context.

    Examples are drawn with the supplied random generator and rendered through
    the task's `doc_to_text` / `doc_to_target` / `doc_to_choice` callables,
    each of which may be overridden by an entry of the same name in the task
    config's `fewshot_config`. Subclasses change the drawing strategy by
    overriding `sample`.
    """

    def __init__(
        self,
        docs: list[dict],
        task: Union["Task", "ConfigurableTask"],
        fewshot_indices: Optional[Iterable] = None,
        rnd: Optional["Random"] = None,
    ) -> None:
        """
        Args:
            docs: pool of few-shot documents to draw from.
            task: task whose `_config` and `doc_to_*` callables are used for rendering.
            fewshot_indices: optional indices used to subset `docs` via `docs.select(...)`.
            rnd: random generator used by `sample`; required.

        Raises:
            ValueError: if `rnd` is not provided, or if `fewshot_indices` is given
                while `docs` is not a `datasets.Dataset`.
        """
        if rnd is None:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )
        self.rnd = rnd

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # Each formatter is the task's own callable unless `fewshot_config`
        # supplies an override under the same name.
        self.doc_to_text = self._resolve_formatter("doc_to_text")
        self.doc_to_target = self._resolve_formatter("doc_to_target")
        self.doc_to_choice = self._resolve_formatter("doc_to_choice")

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def _resolve_formatter(self, name: str):
        """Return `self.task.<name>`, bound to the `fewshot_config` override for `name` if one is set."""
        fewshot_config = self.config.fewshot_config
        override = None
        if fewshot_config is not None:
            override = fewshot_config.get(name, None)

        formatter = getattr(self.task, name)
        if override is None:
            return formatter
        # The override is forwarded as the keyword argument of the same name.
        return partial(formatter, **{name: override})

    def _select_fewshot_docs(self, doc: dict, num_fewshot: int) -> list:
        """Draw few-shot docs and drop any that equal the doc being evaluated."""
        # draw an extra fewshot sample if using same split as evaluating on
        if self.config.fewshot_split == self.config.test_split:
            n_samples = num_fewshot + 1
        else:
            n_samples = num_fewshot

        # draw `n_samples` docs from fewshot_docs
        drawn = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        return [candidate for candidate in drawn if candidate != doc][:num_fewshot]

    def _render_text(self, example: dict, content):
        """Return the prompt text for `example`; a non-string `content` indexes into its choices."""
        if self.config.doc_to_choice is None or isinstance(content, str):
            return content
        return self.doc_to_choice(example)[content]

    def _render_target(self, example: dict, target):
        """Return the answer text for `example`.

        A list target contributes its first element; a non-string target is used
        as an index into the example's choices when `doc_to_choice` is configured.
        """
        if isinstance(target, list):
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            return target
        return str(self.doc_to_choice(example)[target])

    def get_context(
        self, doc: dict, num_fewshot: int, assistant_prefill: Optional[str] = None
    ) -> str:
        """Render the few-shot examples for `doc` as one string.

        Each example contributes its text and, when its target is not the empty
        string, `target_delimiter`, the prefill prefix, the target and
        `fewshot_delimiter`.

        Args:
            doc: the document being evaluated; excluded from the examples.
            num_fewshot: maximum number of examples to include.
            assistant_prefill: optional text placed (followed by a space) before each target.

        Returns:
            The concatenated few-shot context ("" when no example is selected).
        """
        prefix = assistant_prefill + " " if assistant_prefill else ""

        pieces = []
        for example in self._select_fewshot_docs(doc, num_fewshot):
            content = self.doc_to_text(example)
            target = self.doc_to_target(example)

            pieces.append(self._render_text(example, content))
            if target != "":
                pieces.append(self.target_delimiter)
                pieces.append(prefix)
                pieces.append(self._render_target(example, target))
                pieces.append(self.fewshot_delimiter)

        return "".join(pieces)

    def get_chat_context(
        self,
        doc: dict,
        num_fewshot: int,
        fewshot_as_multiturn: bool = False,
        assistant_prefill: Optional[str] = None,
    ) -> list:
        """Render the few-shot examples for `doc` as a list of chat messages.

        Args:
            doc: the document being evaluated; excluded from the examples.
            num_fewshot: maximum number of examples to include.
            fewshot_as_multiturn: if True, emit one user/assistant message pair per
                example; otherwise emit a single user message holding `get_context(...)`.
            assistant_prefill: optional text placed (followed by a space) before each target.

        Returns:
            A list of `{"role": ..., "content": ...}` dicts.
        """
        # TODO: Do we need any other delimiter
        prefix = assistant_prefill + " " if assistant_prefill else ""

        # NOTE: the draw happens before branching, so the single-turn path draws
        # twice (here and again inside `get_context`). Kept as-is because removing
        # the first draw would change which examples a given seed produces.
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        if not fewshot_as_multiturn:
            # get fewshot context as one user turn
            return [
                {
                    "role": "user",
                    "content": self.get_context(
                        doc, num_fewshot, assistant_prefill=assistant_prefill
                    ),
                }
            ]

        chat_history = []
        for example in selected_docs:
            content = self.doc_to_text(example)
            target = self.doc_to_target(example)
            chat_history.append(
                {"role": "user", "content": self._render_text(example, content)}
            )
            chat_history.append(
                {
                    "role": "assistant",
                    "content": prefix + self._render_target(example, target),
                }
            )
        return chat_history

    def sample(self, n: int) -> list:
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
181
class FirstNSampler(ContextSampler):
    def sample(self, n: int):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Returns `self.docs[:n]`.
        NOTE(review): this assumes slicing `self.docs` yields a sequence of docs
        (true for a list) — confirm for other containers.

        Raises:
            AssertionError: if `n` exceeds the number of available few-shot docs.
        """
        available = len(self.docs)
        if n > available:
            # Raised explicitly rather than via `assert` so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError(
                f"Error: number of fewshot samples requested exceeds the {available} that are available."
            )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """
        Placeholder for class-balanced sampling; not implemented, so it returns None.

        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """


haileyschoelkopf's avatar
haileyschoelkopf committed
203
class ManualSampler(ContextSampler):
    def sample(self, n: int) -> None:
        """
        Placeholder sampler; not implemented, so it returns None.
        """
207
208


haileyschoelkopf's avatar
haileyschoelkopf committed
209
210
211
212
213
214
# Maps the sampler name accepted by `get_sampler` to the sampler class it returns.
# Only the samplers listed here can be looked up by name; `BalancedSampler` and
# `ManualSampler` are not registered.
SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


Baber Abbasi's avatar
Baber Abbasi committed
215
def get_sampler(name: str):
    """Return the sampler class registered under `name` in `SAMPLER_REGISTRY`.

    Args:
        name: key of the sampler in `SAMPLER_REGISTRY`.

    Returns:
        The registered sampler class (not an instance).

    Raises:
        ValueError: if no sampler is registered under `name`; the message lists
            the registered names and the original `KeyError` is chained as the cause.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as err:
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from err