from .. import util
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
import torch.nn.functional as F


class MegaBatchMarginLoss(nn.Module):
    def __init__(
        self,
        model,
        positive_margin: float = 0.8,
        negative_margin: float = 0.3,
        use_mini_batched_version: bool = True,
        mini_batch_size: int = 50,
    ):
        """
        Given a large batch (e.g. 500 or more examples) of (anchor_i, positive_i) pairs, find for each pair
        the hardest negative, i.e. the j != i for which cos_sim(anchor_i, positive_j) is maximal. From this, build
        the triplet (anchor_i, positive_i, positive_j), where positive_j serves as the negative.

        Then train as with the triplet loss.

        :param model: SentenceTransformerModel
        :param positive_margin: Positive margin, cos(anchor, positive) should be > positive_margin
        :param negative_margin: Negative margin, cos(anchor, negative) should be < negative_margin
        :param use_mini_batched_version: As large batch sizes require a lot of memory, we can use a mini-batched version.
            We break down the large batch into smaller batches with fewer examples.
        :param mini_batch_size: Size of the mini-batches. Should be a divisor of the batch size in your data loader.

        References:
            - This loss function was inspired by the ParaNMT paper: https://www.aclweb.org/anthology/P18-1042/

        Requirements:
            1. (anchor, positive) pairs
            2. Large batches (500 or more examples)

        Input:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (anchor, positive) pairs              | none   |
            +---------------------------------------+--------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, InputExample, losses
                from torch.utils.data import DataLoader

                model = SentenceTransformer('all-MiniLM-L6-v2')

                total_examples = 500
                train_batch_size = 250
                train_mini_batch_size = 32

                train_examples = [
                    InputExample(texts=[f"This is sentence number {i}", f"This is sentence number {i+1}"]) for i in range(total_examples)
                ]
                train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
                train_loss = losses.MegaBatchMarginLoss(model=model, mini_batch_size=train_mini_batch_size)

                model.fit(
                    [(train_dataloader, train_loss)],
                    epochs=10,
                )
        """
        super().__init__()
        self.model = model
        self.positive_margin = positive_margin
        self.negative_margin = negative_margin
        self.mini_batch_size = mini_batch_size
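        # Bind the selected forward implementation once at construction time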
        self.forward = self.forward_mini_batched if use_mini_batched_version else self.forward_non_mini_batched

    def forward_mini_batched(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
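        """Mini-batched forward pass: mine the hardest negative for each anchor against the
        full mega-batch, then compute the margin loss mini-batch by mini-batch to limit memory use."""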
        anchor, positive = sentence_features
        feature_names = list(anchor.keys())

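        # First pass without gradients: embed every positive in the large batch once,
        # so that hard negatives can be mined against the full batch cheaply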
        with torch.no_grad():
            self.model.eval()
            all_positive_emb = self.model(positive)["sentence_embedding"].detach()
            self.model.train()

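        # Identity matrix used to mask cos(anchor_i, positive_i), so a pair's own positive
        # is never picked as its hard negative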
        diagonal_matrix = torch.eye(len(all_positive_emb), len(all_positive_emb), device=all_positive_emb.device)

        # Iterate over the triplets (anchor, positive, hardest_negative) in smaller mini_batch sizes
        for start_idx in range(0, len(all_positive_emb), self.mini_batch_size):
            end_idx = start_idx + self.mini_batch_size
            anchor_emb = self.model({key: anchor[key][start_idx:end_idx] for key in feature_names})[
                "sentence_embedding"
            ]

            # Find hard negatives. For each anchor, find the hardest negative
            # Store them in the triplets (anchor, positive, hardest_negative)
            hard_negative_features = {key: [] for key in feature_names}
            with torch.no_grad():
                cos_scores = util.pytorch_cos_sim(anchor_emb, all_positive_emb)
                negative_scores = (
                    cos_scores - 2 * diagonal_matrix[start_idx:end_idx]
                )  # Shift the positive scores on the diagonal to <= -1 so they are never selected by the max() operation
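                # Hardest negative per anchor: index of the highest off-diagonal similarity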
                negatives_max, negatives_ids = torch.max(negative_scores, dim=1)

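            # Collect the tokenized features of the mined hard negatives so they can be
            # re-encoded below with gradients enabled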
            for hard_negative_id in negatives_ids:
                for key in feature_names:
                    hard_negative_features[key].append(positive[key][hard_negative_id])

            for key in feature_names:
                hard_negative_features[key] = torch.stack(hard_negative_features[key])

            # Compute differentiable negative and positive embeddings
            positive_emb = self.model({key: positive[key][start_idx:end_idx] for key in feature_names})[
                "sentence_embedding"
            ]
            negative_emb = self.model(hard_negative_features)["sentence_embedding"]

            assert anchor_emb.shape == positive_emb.shape
            assert anchor_emb.shape == negative_emb.shape

            # Margin loss: push cos(anchor, positive) above positive_margin and cos(anchor, negative) below negative_margin
            pos_cosine = F.cosine_similarity(anchor_emb, positive_emb)
            neg_cosine = F.cosine_similarity(anchor_emb, negative_emb)
            losses = F.relu(self.positive_margin - pos_cosine) + F.relu(neg_cosine - self.negative_margin)
            losses = losses.mean()

            # Backpropagate unless it is the last mini-batch. The last mini-batch will be backpropagated by the outside train loop
            if end_idx < len(all_positive_emb):
                losses.backward()

        return losses

    ### Non mini-batched version ###
    def forward_non_mini_batched(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
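        """Non-mini-batched forward pass: encode the full mega-batch at once and mine the
        hardest negatives from the complete similarity matrix."""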
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        embeddings_a, embeddings_b = reps

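        # Full (batch_size x batch_size) cosine similarity matrix; the diagonal holds the
        # scores of the true (anchor_i, positive_i) pairs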
        cos_scores = util.pytorch_cos_sim(embeddings_a, embeddings_b)
        positive_scores = torch.diagonal(cos_scores)
        negative_scores = cos_scores - (
            2 * torch.eye(*cos_scores.shape, device=cos_scores.device)
        )  # Remove positive scores along the diagonal
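        # Hardest negative per anchor: the maximum off-diagonal similarity in each row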
        negatives_max, _ = torch.max(negative_scores, dim=1)
        losses = F.relu(self.positive_margin - positive_scores) + F.relu(negatives_max - self.negative_margin)
        return losses.mean()