samplers.py 4.27 KB
Newer Older
1
2
3
import datasets


haileyschoelkopf's avatar
haileyschoelkopf committed
4
class ContextSampler:
    """Build the few-shot prefix that is prepended to an evaluated document.

    Examples are drawn from ``docs`` by :meth:`sample` and rendered as
    ``<text><target_delimiter><target>``. The rendered examples are joined with
    ``fewshot_delimiter``. Subclasses change *which* examples are drawn by
    overriding :meth:`sample`.
    """

    def __init__(self, docs, task, fewshot_indices=None, rnd=None) -> None:
        """
        Args:
            docs: Pool of candidate few-shot documents (noted upstream as the
                HF dataset split provided by ``task._fewshot_docs()``).
            task: Task object. Its ``_config`` supplies the delimiters and split
                names. Its ``doc_to_text`` / ``doc_to_target`` /
                ``doc_to_choice`` methods render documents.
            fewshot_indices: Optional indices passed to ``docs.select`` to
                restrict the pool. Only allowed when ``docs`` is a
                ``datasets.Dataset``.
            rnd: Random generator used by :meth:`sample`. Required.

        Raises:
            ValueError: if ``rnd`` is not provided, or if ``fewshot_indices``
                is given but ``docs`` is not a ``datasets.Dataset``.
        """
        self.rnd = rnd
        if not self.rnd:
            raise ValueError(
                "A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
            )

        self.task = task
        self.config = task._config

        self.target_delimiter = self.config.target_delimiter
        self.fewshot_delimiter = self.config.fewshot_delimiter

        self.doc_to_text = self.task.doc_to_text
        self.doc_to_target = self.task.doc_to_target
        self.doc_to_choice = self.task.doc_to_choice

        self.docs = docs  # HF dataset split, provided by task._fewshot_docs()
        if fewshot_indices:  # subset few-shot docs from
            if not isinstance(self.docs, datasets.Dataset):
                raise ValueError(
                    "Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
                )
            self.docs = self.docs.select(fewshot_indices)

    def get_context(self, doc, num_fewshot) -> str:
        """Return the rendered few-shot prefix for ``doc``.

        Args:
            doc: The document being evaluated. It is excluded from the examples.
            num_fewshot: Maximum number of labeled examples to include.

        Returns:
            The labeled examples joined by ``fewshot_delimiter``, followed by
            one trailing ``fewshot_delimiter``.
        """
        # When few-shot examples come from the split being evaluated, ``doc``
        # itself may be drawn. Draw one extra so it can be dropped below.
        same_split = self.config.fewshot_split == self.config.test_split
        n_samples = num_fewshot + 1 if same_split else num_fewshot

        # draw `n_samples` docs from fewshot_docs
        fewshotex = self.sample(n_samples)

        # get rid of the doc that's the one we're evaluating, if it's in the fewshot
        # TODO: should we just stop people from using fewshot from same split as evaluating?
        selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]

        # TODO: is separating doc_to_text and doc_to_target by one space always desired?
        labeled_examples = self.fewshot_delimiter.join(
            self._format_example(example) for example in selected_docs
        )
        return labeled_examples + self.fewshot_delimiter

    def _format_example(self, example) -> str:
        """Render one few-shot document as ``text + target_delimiter + target``."""
        return (
            self._example_text(example)
            + self.target_delimiter
            + self._example_target(example)
        )

    def _example_text(self, example):
        """Return the input text shown for ``example``.

        If the task config defines ``doc_to_choice`` and ``doc_to_text`` yields
        a non-string, that value is used to index into the choices.
        """
        text = self.doc_to_text(example)  # computed once per example
        if self.config.doc_to_choice is None or isinstance(text, str):
            return text
        return self.doc_to_choice(example)[text]

    def _example_target(self, example) -> str:
        """Return the target string shown for ``example``.

        For a list target, only the first element is used. A non-string target
        is looked up in the choices when the config defines ``doc_to_choice``.
        Otherwise it is converted with ``str()``, so values such as ints can be
        concatenated.
        """
        target = self.doc_to_target(example)  # computed once per example
        if isinstance(target, list):
            return str(target[0])
        if self.config.doc_to_choice is None or isinstance(target, str):
            # str(): a non-string target here previously raised TypeError
            # when concatenated onto the text.
            return str(target)
        return str(self.doc_to_choice(example)[target])

    def sample(self, n):
        """
        Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.

        The default draws `n` distinct docs uniformly at random using `self.rnd`.
        """

        return self.rnd.sample(self.docs, n)


haileyschoelkopf's avatar
haileyschoelkopf committed
84
85
86
87
88
89
class FirstNSampler(ContextSampler):
    def sample(self, n) -> list:
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Raises:
            ValueError: if `n` exceeds the number of available few-shot docs.
        """
        # Explicit check instead of `assert`, so the guard survives `python -O`.
        if n > len(self.docs):
            raise ValueError(
                f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
            )
        # Index row by row: slicing a `datasets.Dataset` returns a dict of
        # columns rather than a list of row documents. Per-index access yields
        # rows for both plain lists and Datasets.
        return [self.docs[i] for i in range(n)]


class BalancedSampler(ContextSampler):
    def sample(self, n) -> None:
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?

        Not implemented yet. This raises explicitly instead of returning `None`,
        which `get_context` would fail to iterate with an unclear error.
        """
        raise NotImplementedError("BalancedSampler.sample is not implemented yet.")


haileyschoelkopf's avatar
haileyschoelkopf committed
106
class ManualSampler(ContextSampler):
    def sample(self, n) -> None:
        """
        Placeholder for a sampler whose few-shot examples are chosen manually.

        Not implemented yet. This raises explicitly instead of returning `None`,
        which `get_context` would fail to iterate with an unclear error.
        """
        raise NotImplementedError("ManualSampler.sample is not implemented yet.")
110
111


haileyschoelkopf's avatar
haileyschoelkopf committed
112
113
114
115
116
117
118
119
120
121
122
123
124
# Lookup table used by `get_sampler`: sampler name -> sampler class.
SAMPLER_REGISTRY = dict(
    default=ContextSampler,
    first_n=FirstNSampler,
)


def get_sampler(name):
    """Return the sampler class registered under `name` in `SAMPLER_REGISTRY`.

    Raises:
        ValueError: if no sampler is registered under `name`. The message
            lists the supported names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError:
        # `from None`: the underlying KeyError adds nothing beyond this message.
        raise ValueError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY.keys())}"
        ) from None