from typing import Dict, Iterable

import torch
from torch import Tensor, nn


class MSELoss(nn.Module):
    def __init__(self, model):
        """
        Computes the MSE loss between the computed sentence embedding and a target sentence embedding. This loss
        is used when extending sentence embeddings to new languages as described in our publication
        Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation.

        For an example, see `the distillation documentation <../../examples/training/distillation/README.html>`_ on extending language models to new languages.

        Args:
            model: SentenceTransformer model to train (the student model)

        References:
            - Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation: https://arxiv.org/abs/2004.09813
            - `Training > Model Distillation <../../examples/training/distillation/README.html>`_
            - `Training > Multilingual Models <../../examples/training/multilingual/README.html>`_

        Requirements:
            1. Usually uses a finetuned teacher M in a knowledge distillation setup

        Relations:
            - :class:`MarginMSELoss` is equivalent to this loss, but with a margin through a negative pair.

        Input:
            +-----------------------------------------+-----------------------------+
            | Texts                                   | Labels                      |
            +=========================================+=============================+
            | sentence                                | model sentence embeddings   |
            +-----------------------------------------+-----------------------------+
            | sentence_1, sentence_2, ..., sentence_N | model sentence embeddings   |
            +-----------------------------------------+-----------------------------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                student_model = SentenceTransformer("microsoft/mpnet-base")
                teacher_model = SentenceTransformer("all-mpnet-base-v2")
                train_dataset = Dataset.from_dict({
                    "english": ["The first sentence",  "The second sentence", "The third sentence",  "The fourth sentence"],
                    "french": ["La première phrase",  "La deuxième phrase", "La troisième phrase",  "La quatrième phrase"],
                })

                def compute_labels(batch):
                    return {
                        "label": teacher_model.encode(batch["english"])
                    }

                train_dataset = train_dataset.map(compute_labels, batched=True)
                loss = losses.MSELoss(student_model)

                trainer = SentenceTransformerTrainer(
                    model=student_model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
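
            After training, the student model should map sentences from the new language close to the
            teacher's embeddings for the corresponding English sentences. A minimal sketch, reusing the
            models from the example above::

                french_embeddings = student_model.encode(["La première phrase"])
                english_embeddings = teacher_model.encode(["The first sentence"])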
        """
        super().__init__()
        self.model = model
        self.loss_fct = nn.MSELoss()

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
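        """
        Compute the MSE between the student sentence embeddings and the target embeddings.

        Args:
            sentence_features: Iterable of tokenized input dicts, one per text column in the batch.
            labels: Target sentence embeddings, e.g. precomputed with a teacher model.

        Returns:
            Tensor: The mean squared error between the student embeddings and the targets.
        """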
        # Concatenate multiple inputs on the batch dimension
        if len(sentence_features) > 1:
            embeddings = torch.cat([self.model(inputs)["sentence_embedding"] for inputs in sentence_features], dim=0)
            # Repeat the labels for each input
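            # e.g. with "english" and "french" text columns, the student embeddings of both columns
            # are regressed toward the same teacher embeddings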
            return self.loss_fct(embeddings, labels.repeat(len(sentence_features), 1))

        embeddings = self.model(sentence_features[0])["sentence_embedding"]
        return self.loss_fct(embeddings, labels)

    @property
    def citation(self) -> str:
        return """
@inproceedings{reimers-2020-multilingual-sentence-bert,
    title = "Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2020",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/2004.09813",
}
"""