import logging
from collections import defaultdict
from itertools import accumulate, cycle
from typing import List

import torch
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from sentence_transformers.util import is_datasets_available

if is_datasets_available():
    from datasets import Dataset

logger = logging.getLogger(__name__)


class SetEpochMixin:
    """
    Required for a BatchSampler as the Trainer will call set_epoch on the BatchSampler at the beginning of each epoch.
    The BatchSampler can then set the generator seed accordingly.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch


class DefaultBatchSampler(SetEpochMixin, BatchSampler):
    """Plain batch sampler: behaves like a regular ``BatchSampler``, but with ``set_epoch`` support."""


class GroupByLabelBatchSampler(SetEpochMixin, BatchSampler):
    """
    Groups the sample indices by label and fills each batch label by label, so that every batch
    contains (pairs of) samples sharing the same label. Useful for losses that need several samples
    per label within a batch, such as the Batch*TripletLoss classes.
    """

    def __init__(
        self,
        dataset: "Dataset",
        batch_size: int,
        drop_last: bool,
        valid_label_columns: List[str] = None,
        generator: torch.Generator = None,
        seed: int = 0,
    ):
        super().__init__(dataset, batch_size, drop_last)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

        if self.batch_size % 2 == 1:
            raise ValueError("The batch size for `GroupByLabelBatchSampler` must be divisible by 2.")

        for column_name in valid_label_columns or []:
            if column_name in dataset.column_names:
                labels = dataset[column_name]
                break
        else:
            raise ValueError(f"None of the valid_label_columns {valid_label_columns} are in the dataset.")
        del dataset

        groups = defaultdict(list)
        for sample_idx, label in enumerate(labels):
            groups[label].append(sample_idx)

        # Drop labels with fewer than 2 samples and truncate each group to an even number of
        # samples, so that every label contributes complete pairs to a batch.
        self.groups = {
            label: sample_indices[: num_samples * 2]
            for label, sample_indices in groups.items()
            if (num_samples := len(sample_indices) // 2)
        }

    def __iter__(self):
        if self.generator and self.seed:
            self.generator.manual_seed(self.seed + self.epoch)

        labels = list(self.groups.keys())
        partial_batch = []
        for label_idx in torch.randperm(len(self.groups), generator=self.generator):
            label = labels[label_idx]
            samples = self.groups[label]
            partial_batch.extend(samples)
            while len(partial_batch) >= self.batch_size:
                yield partial_batch[: self.batch_size]
                partial_batch = partial_batch[self.batch_size :]

        if not self.drop_last and partial_batch:
            yield partial_batch


class NoDuplicatesBatchSampler(SetEpochMixin, BatchSampler):
    """
    Batch sampler that avoids placing two samples with an identical column value (e.g. the same text)
    in one batch. Useful for losses that treat the other in-batch samples as negatives, where
    duplicates would act as false negatives.
    """

    def __init__(
        self,
        dataset: "Dataset",
        batch_size: int,
        drop_last: bool,
        valid_label_columns: List[str] = [],
        generator: torch.Generator = None,
        seed: int = 0,
    ):
        super().__init__(dataset, batch_size, drop_last)
        if label_columns := set(dataset.column_names) & (set(valid_label_columns) | {"dataset_name"}):
            dataset = dataset.remove_columns(label_columns)
        self.dataset = dataset
        self.batch_size = batch_size
        self.drop_last = drop_last
        self.generator = generator
        self.seed = seed

    def __iter__(self):
        """
        Iterate over the remaining non-yielded indices. For each index, check if the sample values are
        already in the batch. If not, add the sample values to the batch and keep going until the batch
        is full. If the batch is full, yield the batch indices and continue with the next batch.
""" if self.generator and self.seed: self.generator.manual_seed(self.seed + self.epoch) remaining_indices = set(torch.randperm(len(self.dataset), generator=self.generator).tolist()) while remaining_indices: batch_values = set() batch_indices = [] for index in remaining_indices: sample_values = set(self.dataset[index].values()) if sample_values & batch_values: continue batch_indices.append(index) if len(batch_indices) == self.batch_size: yield batch_indices break batch_values.update(sample_values) else: # NOTE: some indices might still have been ignored here if not self.drop_last: yield batch_indices remaining_indices -= set(batch_indices) def __len__(self) -> int: if self.drop_last: return len(self.dataset) // self.batch_size else: return (len(self.dataset) + self.batch_size - 1) // self.batch_size class RoundRobinBatchSampler(SetEpochMixin, BatchSampler): def __init__( self, dataset: ConcatDataset, batch_samplers: List[BatchSampler], generator: torch.Generator, seed: int, ): super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last) self.dataset = dataset self.batch_samplers = batch_samplers self.generator = generator self.seed = seed def __iter__(self): self.generator.manual_seed(self.seed + self.epoch) num_samples = [len(dataset) for dataset in self.dataset.datasets] sample_offsets = [0] + list(accumulate(num_samples)) batch_samplers = [iter(sampler) for sampler in self.batch_samplers] for dataset_idx in cycle(range(len(batch_samplers))): sample_offset = sample_offsets[dataset_idx] try: yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])] except StopIteration: # current iterator is apparently exhausted break def __len__(self) -> int: return min([len(sampler) for sampler in self.batch_samplers]) * len(self.batch_samplers) class ProportionalBatchSampler(SetEpochMixin, BatchSampler): def __init__( self, dataset: ConcatDataset, batch_samplers: List[BatchSampler], generator: torch.Generator, seed: int, ): super().__init__(dataset, batch_samplers[0].batch_size, batch_samplers[0].drop_last) self.dataset = dataset self.batch_samplers = batch_samplers self.generator = generator self.seed = seed def __iter__(self): self.generator.manual_seed(self.seed + self.epoch) num_samples = [len(dataset) for dataset in self.dataset.datasets] sample_offsets = [0] + list(accumulate(num_samples)) num_batches = [len(sampler) for sampler in self.batch_samplers] dataset_indices = [idx for idx, length in enumerate(num_batches) for _ in range(length)] dataset_idx_sampler = SubsetRandomSampler(dataset_indices, generator=self.generator) batch_samplers = [iter(sampler) for sampler in self.batch_samplers] for dataset_idx in dataset_idx_sampler: sample_offset = sample_offsets[dataset_idx] yield [idx + sample_offset for idx in next(batch_samplers[dataset_idx])] def __len__(self) -> int: return sum([len(sampler) for sampler in self.batch_samplers])