MarginMSELoss.py 4.49 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from .. import util
from torch import nn, Tensor
from typing import Iterable, Dict


class MarginMSELoss(nn.Module):
    def __init__(self, model, similarity_fct=util.pairwise_dot_score) -> None:
        """
        Compute the MSE loss between the predicted margin ``sim(Query, Pos) - sim(Query, Neg)`` and the gold margin
        ``gold_sim(Query, Pos) - gold_sim(Query, Neg)``. Both margins are signed differences; no absolute value is taken.
        By default, sim() is the dot-product. The gold_sim is often the similarity score from a teacher model.

        In contrast to :class:`MultipleNegativesRankingLoss`, the two passages do not have to be strictly positive and negative,
        both can be relevant or not relevant for a given query. This can be an advantage of MarginMSELoss over
        MultipleNegativesRankingLoss, but note that the MarginMSELoss is much slower to train. With MultipleNegativesRankingLoss,
        with a batch size of 64, we compare one query against 128 passages. With MarginMSELoss, we compare a query only
        against two passages.

        :param model: SentenceTransformerModel
        :param similarity_fct: Which similarity function to use. It is called as ``similarity_fct(query_embeddings, passage_embeddings)``.

        References:
            - For more details, please refer to https://arxiv.org/abs/2010.02666.
            - `Training Examples > MS MARCO <../../examples/training/ms_marco/README.html>`_
            - `Unsupervised Learning > Domain Adaptation <../../examples/domain_adaptation/README.html>`_

        Requirements:
            1. (query, passage_one, passage_two) triplets
            2. Usually used with a finetuned teacher M in a knowledge distillation setup

        Relations:
            - :class:`MSELoss` is equivalent to this loss, but without a margin through the negative pair.

        Inputs:
            +-----------------------------------------------+-----------------------------------------------+
            | Texts                                         | Labels                                        |
            +===============================================+===============================================+
            | (query, passage_one, passage_two) triplets    | M(query, passage_one) - M(query, passage_two) |
            +-----------------------------------------------+-----------------------------------------------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, InputExample, losses
                from sentence_transformers.util import pairwise_dot_score
                from torch.utils.data import DataLoader
                import torch

                student_model = SentenceTransformer('sentence-transformers/distilbert-base-nli-mean-tokens')
                teacher_model = SentenceTransformer('sentence-transformers/bert-base-nli-stsb-mean-tokens')

                train_examples = [
                    ['The first query',  'The first positive passage',  'The first negative passage'],
                    ['The second query', 'The second positive passage', 'The second negative passage'],
                    ['The third query',  'The third positive passage',  'The third negative passage'],
                ]
                train_batch_size = 1
                encoded = torch.tensor([teacher_model.encode(x).tolist() for x in train_examples])
                labels = pairwise_dot_score(encoded[:, 0], encoded[:, 1]) - pairwise_dot_score(encoded[:, 0], encoded[:, 2])

                train_input_examples = [InputExample(texts=x, label=labels[i]) for i, x in enumerate(train_examples)]
                train_dataloader = DataLoader(train_input_examples, shuffle=True, batch_size=train_batch_size)
                train_loss = losses.MarginMSELoss(model=student_model)

                student_model.fit(
                    [(train_dataloader, train_loss)],
                    epochs=10,
                )
        """
        super().__init__()
        self.model = model
        self.similarity_fct = similarity_fct
        self.loss_fct = nn.MSELoss()

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
        """
        Encode the inputs, compute the predicted margin and return its MSE against ``labels``.

        :param sentence_features: Model inputs in the order (query, first passage, second passage). Each entry is
            passed to ``self.model`` and its ``"sentence_embedding"`` output is used. Only the first three entries
            take part in the loss.
        :param labels: Gold margins, i.e. the target for ``sim(query, first passage) - sim(query, second passage)``.
            If its shape differs from the predicted margin only by singleton dimensions (e.g. a trailing dimension
            of size 1), it is reshaped to match the prediction.
        :return: The value of ``nn.MSELoss`` between the predicted margin and ``labels``.
        """
        # sentence_features: query, positive passage, negative passage
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        embeddings_query = reps[0]
        embeddings_pos = reps[1]
        embeddings_neg = reps[2]

        scores_pos = self.similarity_fct(embeddings_query, embeddings_pos)
        scores_neg = self.similarity_fct(embeddings_query, embeddings_neg)
        margin_pred = scores_pos - scores_neg

        # nn.MSELoss broadcasts mismatched shapes: a prediction of shape (batch,) against labels of shape
        # (batch, 1) would be compared as a (batch, batch) grid and yield a wrong loss instead of an error.
        # Align the labels only when the shapes agree once singleton dimensions are ignored, so genuinely
        # incompatible shapes (and plain scalar broadcasting) keep their original behaviour.
        if labels.shape != margin_pred.shape and labels.squeeze().shape == margin_pred.squeeze().shape:
            labels = labels.reshape(margin_pred.shape)

        return self.loss_fct(margin_pred, labels)