from typing import Any, Dict, Iterable

import torch
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.util import fullname


class CosineSimilarityLoss(nn.Module):
    def __init__(
        self,
        model: "SentenceTransformer",
        # NOTE: nn.MSELoss() / nn.Identity() are evaluated once at definition
        # time and shared across instances; both are stateless, so this is safe.
        loss_fct: nn.Module = nn.MSELoss(),
        cos_score_transformation: nn.Module = nn.Identity(),
    ) -> None:
        """
        CosineSimilarityLoss expects that the InputExamples consists of two texts and a float label. It computes the
        vectors ``u = model(sentence_A)`` and ``v = model(sentence_B)`` and measures the cosine-similarity between the two.
        By default, it minimizes the following loss: ``||input_label - cos_score_transformation(cosine_sim(u,v))||_2``.

        Args:
            model: SentenceTransformer model
            loss_fct: Which pytorch loss function should be used to
                compare the ``cosine_similarity(u, v)`` with the
                input_label? By default, MSE is used: ``||input_label -
                cosine_sim(u, v)||_2``
            cos_score_transformation: The cos_score_transformation
                function is applied on top of cosine_similarity. By
                default, the identity function is used (i.e. no change).

        References:
            - `Training Examples > Semantic Textual Similarity <../../examples/training/sts/README.html>`_

        Requirements:
            1. Sentence pairs with corresponding similarity scores in range `[0, 1]`

        Relations:
            - :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended.
            - :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss.

        Inputs:
            +--------------------------------+------------------------+
            | Texts                          | Labels                 |
            +================================+========================+
            | (sentence_A, sentence_B) pairs | float similarity score |
            +--------------------------------+------------------------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "score": [1.0, 0.3],
                })
                loss = losses.CosineSimilarityLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.loss_fct = loss_fct
        self.cos_score_transformation = cos_score_transformation

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
        """Compute the loss for a batch of sentence pairs.

        Args:
            sentence_features: Iterable of two feature dicts (one per sentence
                in the pair); each is passed through ``self.model`` to obtain
                its ``"sentence_embedding"`` tensor.
            labels: Float similarity scores, one per pair.

        Returns:
            The scalar loss: ``loss_fct(cos_score_transformation(cosine_sim(u, v)), labels)``.
        """
        # Encode both sides of each pair into sentence embeddings.
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        # Row-wise cosine similarity between the two embedding batches,
        # optionally transformed (identity by default).
        output = self.cos_score_transformation(torch.cosine_similarity(embeddings[0], embeddings[1]))
        # Flatten labels to match the 1-D similarity vector and cast to float.
        return self.loss_fct(output, labels.float().view(-1))

    def get_config_dict(self) -> Dict[str, Any]:
        """Return the serializable configuration (the loss function's fully-qualified name)."""
        return {"loss_fct": fullname(self.loss_fct)}