SentenceLabelDataset.py 4.04 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
import logging
from typing import List
Rayyyyy's avatar
Rayyyyy committed
3
4

import numpy as np
Rayyyyy's avatar
Rayyyyy committed
5
6
7
from torch.utils.data import IterableDataset

from sentence_transformers.readers import InputExample
Rayyyyy's avatar
Rayyyyy committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

# Module-level logger, named after this module so its output can be filtered per module.
logger = logging.getLogger(name=__name__)


class SentenceLabelDataset(IterableDataset):
    """
    This dataset can be used for some specific Triplet Losses like BATCH_HARD_TRIPLET_LOSS which requires
    multiple examples with the same label in a batch.

    It draws n consecutive, random and unique samples from one label at a time. This is repeated for each label.

    Labels with fewer than n unique samples are ignored.
    This also applies to drawing without replacement: once fewer than n unseen samples remain for a label, it is
    skipped until every label has been exhausted, after which all samples become available again.

    This *DOES NOT* check whether there are more labels than fit into one batch, or whether the batch size is
    divisible by the number of samples drawn per label.
    """

    def __init__(self, examples: List[InputExample], samples_per_label: int = 2, with_replacement: bool = False):
        """
        Groups the examples by label and keeps only labels that can supply at least one full draw.

        Args:
            examples (List[InputExample]): A list of InputExamples. Each example's ``label`` attribute is used
                as the grouping key.
            samples_per_label (int, optional): The number of consecutive, random, and unique samples drawn per label.
                The batch size should be a multiple of samples_per_label. Must be at least 1. Defaults to 2.
            with_replacement (bool, optional): If True, a sample can be drawn again in later draws of the same
                label (but never twice within one draw). If False, each sample is drawn at most once until no
                label has samples_per_label unseen samples left; then every sample becomes available again.
                Defaults to False.

        Raises:
            ValueError: If samples_per_label is smaller than 1 (iteration could otherwise never finish).
        """
        super().__init__()

        if samples_per_label < 1:
            raise ValueError(f"samples_per_label must be at least 1, got {samples_per_label}")

        self.samples_per_label = samples_per_label
        self.with_replacement = with_replacement

        # Group examples by label, preserving first-seen label order and in-label example order.
        label2ex = {}
        for example in examples:
            label2ex.setdefault(example.label, []).append(example)

        # Keep only labels with at least samples_per_label examples. Kept examples are stored contiguously
        # in grouped_inputs; groups_right_border[k] is the exclusive end index of the k-th kept label's
        # slice, whose start is groups_right_border[k - 1] (or 0 for k == 0).
        self.grouped_inputs = []
        self.groups_right_border = []
        for label_examples in label2ex.values():
            if len(label_examples) >= self.samples_per_label:
                self.grouped_inputs.extend(label_examples)
                self.groups_right_border.append(len(self.grouped_inputs))
        num_labels = len(self.groups_right_border)

        # Visiting order of the kept label groups; reshuffled after every full pass in __iter__.
        self.label_range = np.arange(num_labels)
        np.random.shuffle(self.label_range)

        logger.info(
            "SentenceLabelDataset: %s examples, from which %s examples could be used "
            "(those labels appeared at least %s times). %s different labels found.",
            len(examples),
            len(self.grouped_inputs),
            self.samples_per_label,
            num_labels,
        )

    def __iter__(self):
        """
        Yield examples in groups of ``samples_per_label`` consecutive examples that share a label.

        Label groups are visited in a random order that is reshuffled after every full pass. Iteration stops
        as soon as at least ``len(self)`` examples have been yielded. Because whole groups are always yielded,
        the total is ``len(self)`` rounded up to a multiple of ``samples_per_label``.
        """
        label_idx = 0
        count = 0
        # Per-label set of indices (into grouped_inputs) already drawn; only consulted without replacement.
        already_seen = {}
        # Whether any label produced a draw during the current pass over label_range.
        drew_in_pass = False

        while count < len(self.grouped_inputs):
            label = self.label_range[label_idx]
            seen = already_seen.setdefault(label, set())

            left_border = 0 if label == 0 else self.groups_right_border[label - 1]
            right_border = self.groups_right_border[label]

            if self.with_replacement:
                selection = np.arange(left_border, right_border)
            else:
                selection = [i for i in np.arange(left_border, right_border) if i not in seen]

            if len(selection) >= self.samples_per_label:
                drew_in_pass = True
                # replace=False: a sample never appears twice within the same draw.
                for element_idx in np.random.choice(selection, self.samples_per_label, replace=False):
                    count += 1
                    seen.add(element_idx)
                    yield self.grouped_inputs[element_idx]

            label_idx += 1
            if label_idx >= len(self.label_range):
                label_idx = 0
                np.random.shuffle(self.label_range)
                # Forget previous draws only once a whole pass could not serve any label, i.e. every label
                # is exhausted. Resetting after every pass would make "without replacement" a no-op, since
                # each label is visited only once per pass. After a reset every kept label has at least
                # samples_per_label samples again, so the next pass always makes progress.
                if not drew_in_pass:
                    already_seen = {}
                drew_in_pass = False

    def __len__(self):
        """Return the number of usable examples (those whose label has at least ``samples_per_label`` examples)."""
        return len(self.grouped_inputs)