samplers.py 5.66 KB
Newer Older
haileyschoelkopf's avatar
haileyschoelkopf committed
1
class ContextSampler:
    """Build few-shot context for a task by sampling example docs.

    The sampler reads its delimiters and split names from ``task._config`` and
    uses the task's ``doc_to_text`` / ``doc_to_target`` / ``doc_to_choice``
    callables to render each sampled doc. Subclasses customise the sampling
    strategy by overriding :meth:`sample`.
    """

    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
        """Store the task hooks and the pool of few-shot docs.

        Args:
            docs: Pool of few-shot docs to sample from.
            task: Object exposing ``_config``, ``doc_to_text``,
                ``doc_to_target`` and ``doc_to_choice``.
            fewshot_indices: Optional indices; when truthy the pool is narrowed
                with ``docs.select(fewshot_indices)``.
            rnd: Random generator used by :meth:`sample`.

        Raises:
            ValueError: If ``rnd`` is missing (falsy).
        """
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        self.doc_to_text = self.task.doc_to_text
        self.doc_to_target = self.task.doc_to_target
        self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            self.docs = self.docs.select(fewshot_indices)

    def _select_fewshot_docs(self, doc, num_fewshot):
        """Sample few-shot docs, dropping ``doc`` itself, capped at ``num_fewshot``."""
        # draw an extra fewshot sample if using same split as evaluating on
        n_samples = (
            num_fewshot + 1
            if self.config.fewshot_split == self.config.test_split
            else num_fewshot
        )

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        return [x for x in fewshotex if x != doc][:num_fewshot]

    def _render_question(self, doc):
        """Return the prompt side of a few-shot example."""
        text = self.doc_to_text(doc)
        if self.config.doc_to_choice is None or isinstance(text, str):
            return text
        # A non-string value is used as an index into the doc's choices.
        return self.doc_to_choice(doc)[text]

    def _render_target(self, doc):
        """Return the answer side of a few-shot example."""
        target = self.doc_to_target(doc)
        if isinstance(target, list):
            # Several targets: only the first one is shown.
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            return target
        # A non-string value is used as an index into the doc's choices.
        return str(self.doc_to_choice(doc)[target])

    def get_context(self, doc, num_fewshot):
        """Return the few-shot examples rendered as a single string.

        Each example is ``question + target_delimiter + target``; examples are
        joined by ``fewshot_delimiter`` and one trailing ``fewshot_delimiter``
        is appended (also when no example was selected).

        Args:
            doc: The doc being evaluated; it is excluded from the examples.
            num_fewshot: Maximum number of examples to include.
        """
        selected_docs = self._select_fewshot_docs(doc, num_fewshot)

        # TODO: is separating doc_to_text and doc_to_target by one space always desired?
        labeled_examples = (
            self.fewshot_delimiter.join(
                self._render_question(example)
                + self.target_delimiter
                + self._render_target(example)
                for example in selected_docs
            )
            + self.fewshot_delimiter
        )

        return labeled_examples

    def get_chat_context(
        self,
        doc,
        num_fewshot,
        chat_history: "list | None" = None,
    ):
        """Append the few-shot examples to a chat history as user/assistant turns.

        Args:
            doc: The doc being evaluated; it is excluded from the examples.
            num_fewshot: Maximum number of examples to include.
            chat_history: Existing list of message dicts to extend in place.
                When omitted a fresh list is created for this call, so turns
                never leak between calls.

        Returns:
            The (possibly newly created) ``chat_history`` list.
        """
        if chat_history is None:
            chat_history = []

        for example in self._select_fewshot_docs(doc, num_fewshot):
            chat_history.append(
                {
                    "role": "user",
                    "content": self._render_question(example),
                }
            )
            chat_history.append(
                {
                    "role": "assistant",
                    "content": self._render_target(example),
                }
            )
        return chat_history

    def sample(self, n):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
        """

        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
122
123
124
125
126
127
class FirstNSampler(ContextSampler):
    """Sampler that takes the leading few-shot docs instead of a random draw."""

    def sample(self, n):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Returns ``self.docs[:n]``.

        Raises:
            AssertionError: If `n` exceeds the number of available docs.
        """
        available = len(self.docs)
        if n > available:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type is unchanged.
            raise AssertionError(
                f"Error: number of fewshot samples requested exceeds the {available} that are available."
            )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    # Not implemented yet: `sample` currently yields nothing (returns None).
    def sample(self, n) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?
        """
        return None


haileyschoelkopf's avatar
haileyschoelkopf committed
144
class ManualSampler(ContextSampler):
    # Not implemented yet: `sample` currently yields nothing (returns None).
    def sample(self, n) -> None:
        """Placeholder; no sampling is performed and None is returned."""
        return None
148
149


haileyschoelkopf's avatar
haileyschoelkopf committed
150
151
152
153
154
155
156
157
158
159
160
161
162
# Maps the sampler name accepted by `get_sampler` to the sampler class.
# Key order is kept because it is echoed in the "supported names" error text.
SAMPLER_REGISTRY = dict(
    default=ContextSampler,
    first_n=FirstNSampler,
)


def get_sampler(name):
    """Look up a sampler class in ``SAMPLER_REGISTRY`` by name.

    Args:
        name: Registry key of the sampling strategy.

    Returns:
        The sampler class registered under ``name``.

    Raises:
        ValueError: If ``name`` is not a registered sampler; the message lists
            the registered names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as err:
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from err