import copy
import math
import random
from typing import Dict, Iterable

import numpy as np
import torch
from torch import Tensor, nn

from sentence_transformers import InputExample, util
from sentence_transformers.SentenceTransformer import SentenceTransformer


class ContrastiveTensionLoss(nn.Module):
    """
    This loss expects only single sentences, without any labels. Positive and negative pairs are automatically created via random sampling,
    such that a positive pair consists of two identical sentences and a negative pair consists of two different sentences. An independent
    copy of the encoder model is created, which is used for encoding the first sentence of each pair. The original encoder model encodes the
    second sentence. The embeddings are compared via their dot product and scored against the generated labels (1 if positive, 0 if
    negative) with the binary cross-entropy objective.

    Note that you must use the `ContrastiveTensionDataLoader` for this loss. Its `pos_neg_ratio` determines the number of negative
    pairs per positive pair.

    Generally, :class:`ContrastiveTensionLossInBatchNegatives` is recommended over this loss, as it gives a stronger training signal.

    Args:
        model: SentenceTransformer model

    References:
        * Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF
        * `Unsupervised Learning > CT <../../examples/unsupervised_learning/CT/README.html>`_

    Relations:
        * :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss.

    Inputs:
        +------------------+--------+
        | Texts            | Labels |
        +==================+========+
        | single sentences | none   |
        +------------------+--------+

    Example:
        ::

            from sentence_transformers import SentenceTransformer, losses
            from sentence_transformers.losses import ContrastiveTensionDataLoader

            model = SentenceTransformer('all-MiniLM-L6-v2')
            train_examples = [
                'This is the 1st sentence',
                'This is the 2nd sentence',
                'This is the 3rd sentence',
                'This is the 4th sentence',
                'This is the 5th sentence',
                'This is the 6th sentence',
                'This is the 7th sentence',
                'This is the 8th sentence',
                'This is the 9th sentence',
                'This is the final sentence',
            ]

            train_dataloader = ContrastiveTensionDataLoader(train_examples, batch_size=3, pos_neg_ratio=3)
            train_loss = losses.ContrastiveTensionLoss(model=model)

            model.fit(
                [(train_dataloader, train_loss)],
                epochs=10,
            )
    """

    def __init__(self, model: SentenceTransformer):
        super().__init__()
        self.model2 = model  # This will be the final model, used at inference time
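        # Independent copy of the encoder; both copies receive gradients during
        # training, but only `model2` is kept as the final model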
        self.model1 = copy.deepcopy(model)
        self.criterion = nn.BCEWithLogitsLoss(reduction="sum")

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        sentence_features1, sentence_features2 = tuple(sentence_features)
        reps_1 = self.model1(sentence_features1)["sentence_embedding"]  # (bsz, hdim)
        reps_2 = self.model2(sentence_features2)["sentence_embedding"]

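        # Row-wise dot product between paired embeddings, equivalent to
        # (reps_1 * reps_2).sum(dim=-1): each (s1, s2) pair is scored independently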
        sim_scores = (
            torch.matmul(reps_1[:, None], reps_2[:, :, None]).squeeze(-1).squeeze(-1)
        )  # (bsz,) dot product, i.e. S1S2^T

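        # Binary cross-entropy over the raw dot products; reduction="sum" adds the
        # per-pair losses instead of averaging them over the batch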
        loss = self.criterion(sim_scores, labels.type_as(sim_scores))
        return loss

    @property
    def citation(self) -> str:
        return """
@inproceedings{carlsson2021semantic,
    title={Semantic Re-tuning with Contrastive Tension},
    author={Fredrik Carlsson and Amaru Cuba Gyllensten and Evangelia Gogoulou and Erik Ylip{\"a}{\"a} Hellqvist and Magnus Sahlgren},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=Ov_sMNau-PF}
}
"""


class ContrastiveTensionLossInBatchNegatives(nn.Module):
    def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim):
        """
        This loss expects only single sentences, without any labels. Positive and negative pairs are automatically created via random sampling,
        such that a positive pair consists of two identical sentences and a negative pair consists of two different sentences. An independent
        copy of the encoder model is created, which is used for encoding the first sentence of each pair. The original encoder model encodes the
        second sentence. Unlike :class:`ContrastiveTensionLoss`, this loss uses an in-batch negative sampling strategy, i.e. the negative
        pairs are sampled from the other examples in the batch. In-batch negative sampling gives a stronger training signal than the
        original :class:`ContrastiveTensionLoss`, and performance usually increases with larger batch sizes.

        Note that you should not use the `ContrastiveTensionDataLoader` for this loss, but just a normal DataLoader with `InputExample` instances.
        The two texts of each `InputExample` instance should be identical.

        Args:
            model: SentenceTransformer model
            scale: The output of the similarity function is multiplied by
                this value
            similarity_fct: Similarity function between sentence
                embeddings. By default, cos_sim; can also be set to dot
                product (then set scale to 1)

        References:
            - Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF
            - `Unsupervised Learning > CT (In-Batch Negatives) <../../examples/unsupervised_learning/CT_In-Batch_Negatives/README.html>`_

        Relations:
            * :class:`ContrastiveTensionLoss` does not select negative pairs in-batch, resulting in a weaker training signal than this loss.

        Inputs:
            +------------------------+--------+
            | Texts                  | Labels |
            +========================+========+
            | (anchor, anchor) pairs | none   |
            +------------------------+--------+

        Example:
            ::

                from sentence_transformers import InputExample, SentenceTransformer, losses
                from torch.utils.data import DataLoader

                model = SentenceTransformer('all-MiniLM-L6-v2')
                train_examples = [
                    InputExample(texts=['This is the 1st sentence', 'This is the 1st sentence']),
                    InputExample(texts=['This is the 2nd sentence', 'This is the 2nd sentence']),
                    InputExample(texts=['This is the 3rd sentence', 'This is the 3rd sentence']),
                    InputExample(texts=['This is the 4th sentence', 'This is the 4th sentence']),
                    InputExample(texts=['This is the 5th sentence', 'This is the 5th sentence']),
                ]

                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
                train_loss = losses.ContrastiveTensionLossInBatchNegatives(model=model)

                model.fit(
                    [(train_dataloader, train_loss)],
                    epochs=10,
                )
        """
        super().__init__()
        self.model2 = model  # This will be the final model, used at inference time
        self.model1 = copy.deepcopy(model)
        self.similarity_fct = similarity_fct
        self.cross_entropy_loss = nn.CrossEntropyLoss()
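        # Learnable temperature, stored as log(scale) and exponentiated in forward()
        # so that the effective scale stays positive throughout training (CLIP-style)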
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(scale))

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        sentence_features1, sentence_features2 = tuple(sentence_features)
        embeddings_a = self.model1(sentence_features1)["sentence_embedding"]  # (bsz, hdim)
        embeddings_b = self.model2(sentence_features2)["sentence_embedding"]

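        # (bsz, bsz) similarity matrix: the diagonal entries are the positive
        # (identical-sentence) pairs, all off-diagonal entries act as in-batch negatives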
        scores = self.similarity_fct(embeddings_a, embeddings_b) * self.logit_scale.exp()
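        # The incoming `labels` argument is ignored; the targets are the diagonal indices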
        labels = torch.arange(len(scores), dtype=torch.long, device=scores.device)
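        # Symmetric cross-entropy over the rows and columns of the score matrix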
        return (self.cross_entropy_loss(scores, labels) + self.cross_entropy_loss(scores.t(), labels)) / 2

    @property
    def citation(self) -> str:
        return """
@inproceedings{carlsson2021semantic,
    title={Semantic Re-tuning with Contrastive Tension},
    author={Fredrik Carlsson and Amaru Cuba Gyllensten and Evangelia Gogoulou and Erik Ylip{\"a}{\"a} Hellqvist and Magnus Sahlgren},
    booktitle={International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=Ov_sMNau-PF}
}
"""


################# CT Data Loader #################
# For CT, we need batches in a specific format: every `pos_neg_ratio` consecutive examples
# contain one positive pair (i.e. [sentA, sentA]) and pos_neg_ratio - 1 negative pairs
# (i.e. [sentA, sentB]), e.g. 7 negatives per positive for the default ratio of 8.
# To achieve this, we create a custom DataLoader that produces batches with this property.


class ContrastiveTensionDataLoader:
    def __init__(self, sentences, batch_size, pos_neg_ratio=8):
        self.sentences = sentences
        self.batch_size = batch_size
        self.pos_neg_ratio = pos_neg_ratio
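        # Assigned externally (e.g., SentenceTransformer.fit sets its smart batching
        # collate function here, which tokenizes the InputExamples)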
        self.collate_fn = None

        if self.batch_size % self.pos_neg_ratio != 0:
            raise ValueError(
                f"ContrastiveTensionDataLoader was loaded with a pos_neg_ratio of {pos_neg_ratio} and a batch size of {batch_size}. The batch size must be divisible by the pos_neg_ratio"
            )

    def __iter__(self):
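        # Reshuffle at the start of every epoch so the sampled pairs differ across epochs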
        random.shuffle(self.sentences)
        sentence_idx = 0
        batch = []

        while sentence_idx + 1 < len(self.sentences):
            s1 = self.sentences[sentence_idx]
            if len(batch) % self.pos_neg_ratio > 0:  # Negative (different) pair
                sentence_idx += 1
                s2 = self.sentences[sentence_idx]
                label = 0
            else:  # Positive (identical) pair
                s2 = self.sentences[sentence_idx]
                label = 1

            sentence_idx += 1
            batch.append(InputExample(texts=[s1, s2], label=label))

            if len(batch) >= self.batch_size:
                yield self.collate_fn(batch) if self.collate_fn is not None else batch
                batch = []

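    # Each batch consumes roughly 2 * batch_size sentences: within every group of
    # pos_neg_ratio examples, the positive pair uses one sentence while each of the
    # pos_neg_ratio - 1 negative pairs uses two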
    def __len__(self):
        return math.floor(len(self.sentences) / (2 * self.batch_size))