# samplers.py
from __future__ import annotations

import logging
from random import Random
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    # Imported only for type checkers; annotations in this module are lazy
    # strings (see the __future__ import), so these are never needed at runtime.
    from collections.abc import Iterable, Sequence
    from typing import Any, TypeVar

    # Generic element type used by ContextSampler.remove_doc's signature.
    _T = TypeVar("_T")


# Logger named after this module.
eval_logger = logging.getLogger(name=__name__)


class ContextSampler:
    """Randomly draw few-shot example documents from a fixed pool.

    The pool is ``docs``, optionally narrowed to ``fewshot_indices``. Draws use
    a private :class:`random.Random` seeded with ``rnd``, so a given seed yields
    a reproducible sequence of samples.
    """

    def __init__(
        self,
        docs: Sequence[dict[str, Any]] | None = None,
        *,
        rnd: int | None = None,
        fewshot_indices: list[int] | None = None,
        **kwargs,
    ) -> None:
        """
        Args:
            docs: Candidate few-shot documents; ``None`` means an empty pool.
            rnd: Seed for this sampler's private RNG. ``None`` lets
                ``random.Random`` seed itself from OS entropy / system time,
                so draws are not reproducible.
            fewshot_indices: If non-empty (and the pool is non-empty), the pool
                becomes ``[docs[i] for i in fewshot_indices]``, in that order.
            **kwargs: Accepted and ignored.
        """
        self.rnd = Random(rnd)
        self.docs = docs or []
        self.fewshot_indices = fewshot_indices

        if self.fewshot_indices and self.docs:
            self.docs = [self.docs[i] for i in self.fewshot_indices]

    def sample(
        self, n: int, doc: dict[str, Any] | None = None, **kwargs
    ) -> Sequence[dict]:
        """
        Sample n documents from the pool.

        Args:
            n: Number of documents to sample. ``n <= 0`` returns ``[]``.
            doc: Optional document to exclude from sampling. Exclusion is by
                equality (``!=``), so every pool entry equal to ``doc`` is
                excluded from the result.
            **kwargs: Accepted and ignored.

        Returns:
            List of at most ``n`` sampled documents in random order, none
            equal to ``doc``. It can be shorter than ``n`` when ``doc`` matches
            several drawn items, or when the pool has no spare document beyond
            ``n`` to replace an excluded one.

        Raises:
            ValueError: If ``n`` exceeds the pool size (from ``random.sample``).
        """
        if n <= 0:
            return []
        if not doc:
            return self.rnd.sample(self.docs, n)

        # Draw one spare so that n documents remain if `doc` turns up in the
        # draw. Only do so when the pool can supply it; otherwise requesting
        # n + 1 would make random.sample raise even though n is satisfiable.
        n_draw = n + 1 if n < len(self.docs) else n
        drawn = self.rnd.sample(self.docs, n_draw)
        # Truncate: when `doc` was not drawn, the spare must not leak out and
        # turn an n-shot request into an (n + 1)-shot result.
        return self.remove_doc(doc, drawn)[:n]

    def set_rnd(self, rnd: int) -> None:
        """Replace the RNG with a fresh ``Random(rnd)``, restarting the draw sequence."""
        self.rnd = Random(rnd)

    @staticmethod
    def remove_doc(doc: _T, _iter: Iterable[_T]) -> list[_T]:
        """Return the items of ``_iter`` that are not equal (``!=``) to ``doc``.

        Every item equal to ``doc`` is dropped, not just the first one.
        """
        return [x for x in _iter if x != doc]


class FirstNSampler(ContextSampler):
    """Sampler that returns few-shot examples in their stored order (no randomness)."""

    def sample(self, n: int, doc=None, **kwargs):
        """
        Draw the first `n` samples in order from the specified split.
        Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.

        Args:
            n: Number of documents to return. ``n <= 0`` returns ``[]``.
            doc: Ignored; unlike the random sampler, no document is excluded.
            **kwargs: Ignored.

        Returns:
            The first ``n`` entries of ``self.docs``.

        Raises:
            AssertionError: If ``n`` exceeds the number of available docs.
        """
        if n <= 0:
            # Without this guard a negative n would slice from the end
            # (docs[:-1] is every doc except the last) instead of returning
            # nothing, which disagrees with ContextSampler.sample.
            return []
        if n > len(self.docs):
            # Raised explicitly rather than via `assert` so the check survives
            # `python -O`; the exception type is unchanged for existing callers.
            raise AssertionError(
                f"Error: number of fewshot samples requested exceeds the {len(self.docs)} that are available."
            )
        return self.docs[:n]


class BalancedSampler(ContextSampler):
    """Placeholder for a class-balanced few-shot sampler; not implemented yet."""

    def sample(self, n: int, doc=None, **kwargs):
        """Not implemented: always raises ``NotImplementedError``.

        TODO: return approximately class-balanced samples from the fewshot examples.
        TODO: decide what order they should be returned in (maybe random?).
        """
        raise NotImplementedError()


class ManualSampler(ContextSampler):
    """Placeholder sampler whose ``sample`` is not implemented.

    NOTE(review): presumably intended for hand-picked few-shot examples, going
    by the name; confirm before implementing.
    """

    def sample(self, n: int, doc=None, **kwargs):
        """Not implemented: always raises ``NotImplementedError``."""
        raise NotImplementedError()


# Sampler name -> sampler class, as looked up by get_sampler().
# BalancedSampler and ManualSampler are not registered; their sample()
# raises NotImplementedError.
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = dict(
    default=ContextSampler,
    first_n=FirstNSampler,
)


def get_sampler(name: str) -> type[ContextSampler]:
    """Look up a sampler class by its registered name.

    Args:
        name: Key in ``SAMPLER_REGISTRY`` (for example ``"default"`` or ``"first_n"``).

    Returns:
        The registered sampler class (not an instance).

    Raises:
        KeyError: If ``name`` is not registered. The message lists the
            supported sampler names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as e:
        # The old message said "model names"; these are sampler names.
        raise KeyError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from e