samplers.py 6.02 KB
Newer Older
1
2
3
import datasets


haileyschoelkopf's avatar
haileyschoelkopf committed
4
class ContextSampler:
    """Select few-shot documents for a task and render them as prompt context.

    The sampler reads its delimiters and split names from `task._config` and
    uses the task's `doc_to_text` / `doc_to_target` / `doc_to_choice` callables
    to turn each selected document into a question/answer pair.
    """

    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
        """Store the few-shot pool, the task callables and the RNG.

        Args:
            docs: Pool of few-shot documents to draw from (per the original
                comment: an HF dataset split provided by task._fewshot_docs()).
            task: Task object exposing `_config`, `doc_to_text`,
                `doc_to_target` and `doc_to_choice`.
            fewshot_indices: Optional indices used to subset `docs` through
                `docs.select(...)`; only accepted when `docs` is a
                `datasets.Dataset`.
            rnd: Generator whose `sample(population, k)` method is used by
                `sample`; required.

        Raises:
            ValueError: If `rnd` is missing/falsy, or if `fewshot_indices` is
                given while `docs` is not a `datasets.Dataset`.
        """
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        # A `doc_to_text` entry in the few-shot config takes precedence over
        # the task-level renderer.
        fewshot_config = self.config.fewshot_config
        fewshot_doc_to_text = (
            fewshot_config.get("doc_to_text", None)
            if fewshot_config is not None
            else None
        )
        if fewshot_doc_to_text is not None:
            self.doc_to_text = fewshot_doc_to_text
        else:
            self.doc_to_text = self.task.doc_to_text

        self.doc_to_target = self.task.doc_to_target
        self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # restrict the few-shot pool to the given indices
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def _select_fewshot_docs(self, doc, num_fewshot):
        """Draw few-shot docs via `sample`, excluding `doc`, capped at `num_fewshot`."""
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )
        drawn = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        return [candidate for candidate in drawn if candidate != doc][:num_fewshot]

    def _render_example(self, fewshot_doc):
        """Return the `(question, answer)` pair rendered for one few-shot doc."""
        doc_content = self.doc_to_text(fewshot_doc)
        doc_target = self.doc_to_target(fewshot_doc)

        # A non-string question is treated as an index into the doc's choices
        # when the task config defines `doc_to_choice`.
        if self.config.doc_to_choice is None or isinstance(doc_content, str):
            question = doc_content
        else:
            question = self.doc_to_choice(fewshot_doc)[doc_content]

        if isinstance(doc_target, list):
            # Only the first entry of a list target is used.
            answer = str(doc_target[0])
        elif self.config.doc_to_choice is None or isinstance(doc_target, str):
            answer = doc_target
        else:
            # A non-string target is treated as an index into the doc's choices.
            answer = str(self.doc_to_choice(fewshot_doc)[doc_target])

        return question, answer

    def get_context(self, doc, num_fewshot):
        """Build the few-shot context string for `doc`.

        Each selected example contributes
        `question + target_delimiter + answer + fewshot_delimiter`.

        Args:
            doc: The document being evaluated; it is removed from the drawn
                examples if present.
            num_fewshot: Maximum number of few-shot examples to include.

        Returns:
            The concatenated examples, or "" when no example is selected.
        """
        pieces = []
        for fewshot_doc in self._select_fewshot_docs(doc, num_fewshot):
            question, answer = self._render_example(fewshot_doc)
            pieces.extend(
                (question, self.target_delimiter, answer, self.fewshot_delimiter)
            )
        return "".join(pieces)

    def get_chat_context(
        self,
        doc,
        num_fewshot,
        fewshot_as_multiturn: bool = False,
    ):
        """Build the few-shot context as a list of chat messages.

        Args:
            doc: The document being evaluated; it is removed from the drawn
                examples if present.
            num_fewshot: Maximum number of few-shot examples to include.
            fewshot_as_multiturn: When True, every example becomes a "user"
                message followed by an "assistant" message. When False, the
                whole `get_context` string is returned as one "user" message.

        Returns:
            A list of `{"role": ..., "content": ...}` dicts.
        """
        # NOTE(review): this draw happens before the branch, so the single-turn
        # path below discards it and `get_context` draws again. It is kept
        # as-is so the number of `sample` calls (and therefore the RNG stream)
        # stays identical to the previous implementation.
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        if not fewshot_as_multiturn:
            # get fewshot context as one user turn
            return [{"role": "user", "content": self.get_context(doc, num_fewshot)}]

        chat_history = []
        for fewshot_doc in selected_docs:
            question, answer = self._render_example(fewshot_doc)
            chat_history.append({"role": "user", "content": question})
            chat_history.append({"role": "assistant", "content": answer})
        return chat_history

    def sample(self, n):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
130
131
132
133
134
135
class FirstNSampler(ContextSampler):
    """Sampler that returns the leading few-shot docs in their stored order."""

    def sample(self, n):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Args:
            n: Number of leading documents to return.

        Returns:
            `self.docs[:n]`.

        Raises:
            AssertionError: If `n` exceeds the number of available documents.
        """
        available = len(self.docs)
        if n > available:
            # Raised explicitly instead of via `assert` so the check is not
            # stripped when Python runs with -O; the exception type and
            # message are unchanged.
            raise AssertionError(
                f"Error: number of fewshot samples requested exceeds the {available} that are available."
            )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    """Placeholder for a class-balanced few-shot sampler; not implemented yet."""

    def sample(self, n) -> None:
        """
        Stub: draws nothing and returns None.

        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        return None


haileyschoelkopf's avatar
haileyschoelkopf committed
152
class ManualSampler(ContextSampler):
    """Placeholder sampler; `sample` has no implementation yet."""

    def sample(self, n) -> None:
        """Stub: draws nothing and returns None."""
        return None
156
157


haileyschoelkopf's avatar
haileyschoelkopf committed
158
159
160
161
162
163
164
165
166
167
168
169
170
# Maps the sampler names accepted by `get_sampler` to their sampler classes.
# Only the samplers listed here can be looked up by name.
SAMPLER_REGISTRY = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}


def get_sampler(name):
    """Return the sampler class registered under `name`.

    Args:
        name: Key to look up in `SAMPLER_REGISTRY`.

    Returns:
        The sampler class stored under `name` (the class, not an instance).

    Raises:
        ValueError: If `name` is not a key of `SAMPLER_REGISTRY`; the message
            lists the registered names and the original `KeyError` is attached
            as the cause.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as err:
        # The message lists sampler names, so it no longer calls them "model names".
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from err