sampler.py 7.33 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
import logging
from collections import defaultdict
from itertools import accumulate, cycle
from typing import List

import torch
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from sentence_transformers.util import is_datasets_available

# `datasets` is treated as optional: `Dataset` is only imported when `is_datasets_available()` is truthy.
# The samplers visible in this module reference `Dataset` only through string annotations.
if is_datasets_available():
    from datasets import Dataset

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class SetEpochMixin:
    """
    Mixin that lets a BatchSampler keep track of the current epoch.

    The Trainer is expected to call ``set_epoch`` on the BatchSampler at the start of every epoch; samplers
    can then use ``self.epoch`` (for example to derive the seed of their generator).
    """

    def __init__(self, *args, **kwargs) -> None:
        # Cooperative initialisation: pass every argument on to the next class in the MRO first,
        # then start counting epochs at 0.
        super().__init__(*args, **kwargs)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        """Store ``epoch`` as the epoch that is currently being sampled."""
        self.epoch = epoch


class DefaultBatchSampler(SetEpochMixin, BatchSampler):
    """The ``BatchSampler`` from ``torch.utils.data`` combined with ``SetEpochMixin``; it adds no behaviour of its own."""


class GroupByLabelBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler that keeps samples sharing a label next to each other.

    The sample indices are grouped per label and every group is trimmed to an even number of samples
    (labels with fewer than two samples are dropped). On iteration the labels are visited in a random
    order and their sample indices are chained into batches of ``batch_size`` indices.

    Args:
        dataset: The dataset to sample from. It must expose ``column_names`` and support column access
            through ``dataset[column_name]``.
        batch_size: Number of sample indices per batch. Must be even.
        drop_last: If True, a trailing batch with fewer than ``batch_size`` indices is not yielded.
        valid_label_columns: Candidate names for the label column. The first candidate that is present in
            ``dataset.column_names`` is used.
        generator: Optional ``torch.Generator`` that drives the random order of the labels.
        seed: Base seed. When a generator and a non-zero seed are given, the generator is re-seeded with
            ``seed + epoch`` at the start of every iteration.

    Raises:
        ValueError: If ``batch_size`` is odd, or if none of ``valid_label_columns`` is a column of ``dataset``.
    """

    def __init__(
        self,
        dataset: "Dataset",
        batch_size: int,
        drop_last: bool,
        valid_label_columns: List[str] = None,
        generator: torch.Generator = None,
        seed: int = 0,
    ):
        super().__init__(dataset, batch_size, drop_last)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

        if self.batch_size % 2 == 1:
            raise ValueError("The batch size for `GroupByLabelBatchSampler` must be divisible by 2.")

        for column_name in valid_label_columns or []:
            if column_name in dataset.column_names:
                # Read the column that actually matched instead of a hard-coded "label" column.
                labels = dataset[column_name]
                break
        else:
            raise ValueError(f"None of the valid_label_columns {valid_label_columns} are in the dataset.")

        groups = defaultdict(list)
        for sample_idx, label in enumerate(labels):
            groups[label].append(sample_idx)

        # Keep the largest even number of samples per label. `len // 2 * 2` is 0 for labels with a single
        # sample, so the walrus condition also filters those labels out.
        self.groups = {
            label: sample_indices[:num_samples]
            for label, sample_indices in groups.items()
            if (num_samples := len(sample_indices) // 2 * 2)
        }

    def __iter__(self):
        """
        Yield lists of sample indices, visiting the labels in a random order.

        The indices of each visited label are appended to a running buffer, and a batch is emitted every time
        the buffer holds at least ``batch_size`` indices. Whatever is left at the end is yielded as a smaller
        batch unless ``drop_last`` is set.
        """
        if self.generator is not None and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        labels = list(self.groups.keys())
        partial_batch = []
        for label_idx in torch.randperm(len(self.groups), generator=self.generator):
            label = labels[label_idx]
            partial_batch.extend(self.groups[label])
            while len(partial_batch) >= self.batch_size:
                yield partial_batch[: self.batch_size]
                partial_batch = partial_batch[self.batch_size :]

        if not self.drop_last and partial_batch:
            yield partial_batch


class NoDuplicatesBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler that avoids repeating a value within a batch.

    A sample is only added to the batch under construction when none of its column values has already been
    seen in that batch. Label columns and the ``dataset_name`` column are removed from the dataset beforehand,
    so they do not take part in the comparison.

    Args:
        dataset: The dataset to sample from. It must expose ``column_names`` and ``remove_columns``, support
            ``len()``, and return a dict of hashable values for ``dataset[index]``.
        batch_size: Number of sample indices per batch.
        drop_last: If True, a batch with fewer than ``batch_size`` indices is not yielded.
        valid_label_columns: Names of label columns to ignore when looking for duplicates.
        generator: Optional ``torch.Generator`` that drives the random order of the samples.
        seed: Base seed. When a generator and a non-zero seed are given, the generator is re-seeded with
            ``seed + epoch`` at the start of every iteration.
    """

    def __init__(
        self,
        dataset: "Dataset",
        batch_size: int,
        drop_last: bool,
        valid_label_columns: List[str] = None,
        generator: torch.Generator = None,
        seed: int = 0,
    ):
        super().__init__(dataset, batch_size, drop_last)
        ignored_columns = set(valid_label_columns or []) | {"dataset_name"}
        if label_columns := set(dataset.column_names) & ignored_columns:
            # Hand over a list: an unordered set is not a documented input for `remove_columns`.
            dataset = dataset.remove_columns(list(label_columns))
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

    def __iter__(self):
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are already in the
        batch. If not, add the index to the batch and keep going until the batch is full. Once the batch is full,
        yield the batch indices and continue with the next batch.
        """
        if self.generator is not None and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        # A dict is used as an ordered set: it offers cheap removal while preserving the random order of the
        # permutation. A plain `set` of integers would be iterated in hash order, which discards the shuffle.
        remaining_indices = dict.fromkeys(torch.randperm(len(self.dataset), generator=self.generator).tolist())
        while remaining_indices:
            batch_values = set()
            batch_indices = []
            for index in remaining_indices:
                sample_values = set(self.dataset[index].values())
                if sample_values & batch_values:
                    continue

                batch_indices.append(index)
                if len(batch_indices) == self.batch_size:
                    yield batch_indices
                    break

                batch_values.update(sample_values)

            else:
                # NOTE: some indices might still have been ignored here
                if not self.drop_last:
                    yield batch_indices

            # Removal happens after the for-loop, so the dict is never mutated while it is being iterated.
            for index in batch_indices:
                del remaining_indices[index]

    def __len__(self) -> int:
        """Return the batch count implied by the dataset size alone: floor division with ``drop_last``, ceiling otherwise."""
        if self.drop_last:
            return len(self.dataset) // self.batch_size
        else:
            return (len(self.dataset) + self.batch_size - 1) // self.batch_size


class RoundRobinBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler over a ``ConcatDataset`` that takes one batch from each underlying batch sampler in turn.

    The per-dataset indices produced by the batch samplers are shifted by the combined size of all preceding
    datasets, so that they index into the ``ConcatDataset``. Iteration stops as soon as the batch sampler whose
    turn it is has no batch left.

    Args:
        dataset: The ``ConcatDataset`` whose ``datasets`` correspond, in order, to ``batch_samplers``.
        batch_samplers: One batch sampler per dataset; ``batch_size`` and ``drop_last`` are taken from the first.
        generator: ``torch.Generator`` that is re-seeded with ``seed + epoch`` at the start of every iteration.
        seed: Base seed for ``generator``.
    """

    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: List[BatchSampler],
        generator: torch.Generator,
        seed: int,
    ):
        first_sampler = batch_samplers[0]
        super().__init__(dataset, first_sampler.batch_size, first_sampler.drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self):
        self.generator.manual_seed(self.seed + self.epoch)

        # offsets[i] is the position of the first sample of dataset i inside the ConcatDataset
        offsets = [0, *accumulate(len(subset) for subset in self.dataset.datasets)]
        iterators = [iter(batch_sampler) for batch_sampler in self.batch_samplers]

        for position in cycle(range(len(iterators))):
            shift = offsets[position]
            try:
                yield [local_idx + shift for local_idx in next(iterators[position])]
            except StopIteration:
                # the sampler whose turn it is has run out of batches, which ends the whole iteration
                return

    def __len__(self) -> int:
        shortest = min(len(batch_sampler) for batch_sampler in self.batch_samplers)
        return shortest * len(self.batch_samplers)


class ProportionalBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler over a ``ConcatDataset`` that draws from each dataset in proportion to its number of batches.

    A schedule is built in which every dataset position occurs as often as its batch sampler reports batches
    (via ``len``). The schedule is visited in random order, and for every entry the next batch of the matching
    batch sampler is yielded, with its indices shifted so that they index into the ``ConcatDataset``.

    Args:
        dataset: The ``ConcatDataset`` whose ``datasets`` correspond, in order, to ``batch_samplers``.
        batch_samplers: One batch sampler per dataset; ``batch_size`` and ``drop_last`` are taken from the first.
        generator: ``torch.Generator`` that is re-seeded with ``seed + epoch`` at the start of every iteration
            and that also shuffles the schedule.
        seed: Base seed for ``generator``.
    """

    def __init__(
        self,
        dataset: ConcatDataset,
        batch_samplers: List[BatchSampler],
        generator: torch.Generator,
        seed: int,
    ):
        first_sampler = batch_samplers[0]
        super().__init__(dataset, first_sampler.batch_size, first_sampler.drop_last)
        self.dataset = dataset
        self.batch_samplers = batch_samplers
        self.generator = generator
        self.seed = seed

    def __iter__(self):
        self.generator.manual_seed(self.seed + self.epoch)

        # offsets[i] is the position of the first sample of dataset i inside the ConcatDataset
        offsets = [0]
        offsets.extend(accumulate(len(subset) for subset in self.dataset.datasets))

        # One schedule entry per batch, labelled with the position of the dataset it has to come from
        schedule = []
        for position, batch_sampler in enumerate(self.batch_samplers):
            schedule.extend([position] * len(batch_sampler))
        shuffled_schedule = SubsetRandomSampler(schedule, generator=self.generator)

        iterators = [iter(batch_sampler) for batch_sampler in self.batch_samplers]
        for position in shuffled_schedule:
            shift = offsets[position]
            local_batch = next(iterators[position])
            yield [local_idx + shift for local_idx in local_batch]

    def __len__(self) -> int:
        return sum(len(batch_sampler) for batch_sampler in self.batch_samplers)