import logging
from typing import Callable, Dict, Iterable

import torch
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class SoftmaxLoss(nn.Module):
    # NOTE: the default ``loss_fct`` instance is created once, when the class body is
    # evaluated, and is therefore shared by every SoftmaxLoss that relies on the default.
    # ``nn.CrossEntropyLoss()`` built with default arguments holds no learnable parameters,
    # so the sharing is harmless; the default is kept as-is for interface compatibility.
    def __init__(
        self,
        model: SentenceTransformer,
        sentence_embedding_dimension: int,
        num_labels: int,
        concatenation_sent_rep: bool = True,
        concatenation_sent_difference: bool = True,
        concatenation_sent_multiplication: bool = False,
        loss_fct: Callable = nn.CrossEntropyLoss(),
    ):
        """
        This loss was used in our SBERT publication (https://arxiv.org/abs/1908.10084) to train the SentenceTransformer
        model on NLI data. It adds a softmax classifier on top of the output of two transformer networks.

        :class:`MultipleNegativesRankingLoss` is an alternative loss function that often yields better results,
        as per https://arxiv.org/abs/2004.09813.

        Args:
            model (SentenceTransformer): The SentenceTransformer model.
            sentence_embedding_dimension (int): The dimension of the sentence embeddings.
            num_labels (int): The number of different labels.
            concatenation_sent_rep (bool): Whether to concatenate vectors u,v for the softmax classifier. Defaults to True.
            concatenation_sent_difference (bool): Whether to add abs(u-v) for the softmax classifier. Defaults to True.
            concatenation_sent_multiplication (bool): Whether to add u*v for the softmax classifier. Defaults to False.
            loss_fct (Callable): Custom pytorch loss function. If not set, uses nn.CrossEntropyLoss(). Defaults to nn.CrossEntropyLoss().

        Raises:
            ValueError: If all three ``concatenation_*`` options are disabled, because the classifier
                would then have no input features.

        References:
            - Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks: https://arxiv.org/abs/1908.10084
            - `Training Examples > Natural Language Inference <../../examples/training/nli/README.html>`_

        Requirements:
            1. sentence pairs with a class label

        Inputs:
            +---------------------------------------+--------+
            | Texts                                 | Labels |
            +=======================================+========+
            | (sentence_A, sentence_B) pairs        | class  |
            +---------------------------------------+--------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": [
                        "A person on a horse jumps over a broken down airplane.",
                        "A person on a horse jumps over a broken down airplane.",
                        "A person on a horse jumps over a broken down airplane.",
                        "Children smiling and waving at camera",
                    ],
                    "sentence2": [
                        "A person is training his horse for a competition.",
                        "A person is at a diner, ordering an omelette.",
                        "A person is outdoors, on a horse.",
                        "There are children present.",
                    ],
                    "label": [1, 2, 0, 0],
                })
                loss = losses.SoftmaxLoss(model, model.get_sentence_embedding_dimension(), num_labels=3)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()

        # u and v contribute two vectors; |u - v| and u * v contribute one each.
        num_vectors_concatenated = 0
        if concatenation_sent_rep:
            num_vectors_concatenated += 2
        if concatenation_sent_difference:
            num_vectors_concatenated += 1
        if concatenation_sent_multiplication:
            num_vectors_concatenated += 1

        # Fail fast: with nothing to concatenate, forward() would only break later
        # inside torch.cat() on an empty list, far away from the misconfiguration.
        if num_vectors_concatenated == 0:
            raise ValueError(
                "SoftmaxLoss requires at least one of concatenation_sent_rep, "
                "concatenation_sent_difference or concatenation_sent_multiplication to be enabled."
            )

        self.model = model
        self.num_labels = num_labels
        self.concatenation_sent_rep = concatenation_sent_rep
        self.concatenation_sent_difference = concatenation_sent_difference
        self.concatenation_sent_multiplication = concatenation_sent_multiplication

        logger.info("Softmax loss: #Vectors concatenated: %s", num_vectors_concatenated)
        # The classifier is created on the model's device so its weights live alongside the embeddings.
        self.classifier = nn.Linear(
            num_vectors_concatenated * sentence_embedding_dimension, num_labels, device=model.device
        )
        self.loss_fct = loss_fct

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
        """
        Embed both sentences of each pair, classify the combined representation and compute the loss.

        Args:
            sentence_features (Iterable[Dict[str, Tensor]]): Exactly two feature dicts, one per side of the
                sentence pair. Each is passed to ``self.model`` and the ``"sentence_embedding"`` entry of the
                result is used.
            labels (Tensor): Class labels; flattened with ``view(-1)`` before being handed to ``loss_fct``.
                May be ``None`` to skip the loss computation.

        Returns:
            The value of ``loss_fct(output, labels.view(-1))`` when ``labels`` is not ``None``; otherwise a tuple
            ``(reps, output)`` holding the list of the two sentence embeddings and the classifier output.

        Raises:
            ValueError: If ``sentence_features`` does not yield exactly two items.
        """
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        rep_a, rep_b = reps

        # The order (u, v, |u - v|, u * v) must stay fixed: it defines the classifier's input layout.
        vectors_concat = []
        if self.concatenation_sent_rep:
            vectors_concat.extend((rep_a, rep_b))
        if self.concatenation_sent_difference:
            vectors_concat.append(torch.abs(rep_a - rep_b))
        if self.concatenation_sent_multiplication:
            vectors_concat.append(rep_a * rep_b)

        # Concatenate along dim 1 (presumably the embedding dimension of (batch, dim) tensors - TODO confirm).
        features = torch.cat(vectors_concat, 1)
        output = self.classifier(features)

        if labels is None:
            return reps, output
        return self.loss_fct(output, labels.view(-1))

    @property
    def citation(self) -> str:
        """Return the BibTeX entry of the Sentence-BERT paper."""
        return """
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
"""