from enum import Enum
from typing import Callable, Dict, Iterable

import torch.nn.functional as F
from torch import Tensor, nn

from sentence_transformers.SentenceTransformer import SentenceTransformer


class SiameseDistanceMetric(Enum):
    """The metric for the contrastive loss.

    Each attribute is a function ``(x, y) -> Tensor`` giving the distance
    between corresponding embeddings of two batches. For inputs of shape
    ``(batch_size, embedding_dim)`` each returns a tensor of shape
    ``(batch_size,)``:

    * ``EUCLIDEAN``: L2 distance, ``F.pairwise_distance(x, y, p=2)``.
    * ``MANHATTAN``: L1 distance, ``F.pairwise_distance(x, y, p=1)``.
    * ``COSINE_DISTANCE``: ``1 - F.cosine_similarity(x, y)``.

    Note: functions defined in an ``Enum`` body are descriptors, so ``Enum``
    does not convert them into members. ``SiameseDistanceMetric.EUCLIDEAN``
    and friends are therefore plain callables that can be invoked directly
    and found by value in ``vars(SiameseDistanceMetric)``.
    """

    EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
    MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
    COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)


class ContrastiveLoss(nn.Module):
    def __init__(
        self,
        model: SentenceTransformer,
        distance_metric: Callable[[Tensor, Tensor], Tensor] = SiameseDistanceMetric.COSINE_DISTANCE,
        margin: float = 0.5,
        size_average: bool = True,
    ):
        """
        Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
        two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.

        For a pair with distance ``d`` and label ``y``, the per-pair loss is
        ``0.5 * (y * d**2 + (1 - y) * max(0, margin - d)**2)``.

        Args:
            model: SentenceTransformer model
            distance_metric: Function that returns a distance between
                two embeddings. The class SiameseDistanceMetric contains
                pre-defined metrics that can be used
            margin: Negative samples (label == 0) should have a distance
                of at least the margin value.
            size_average: If True, return the mean of the per-pair losses
                over the mini-batch; if False, return their sum.

        References:
            * Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
            * `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_

        Requirements:
            1. (anchor, positive/negative) pairs

        Relations:
            - :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs.
              It often yields better results.

        Inputs:
            +-----------------------------------------------+------------------------------+
            | Texts                                         | Labels                       |
            +===============================================+==============================+
            | (anchor, positive/negative) pairs             | 1 if positive, 0 if negative |
            +-----------------------------------------------+------------------------------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "label": [1, 0],
                })
                loss = losses.ContrastiveLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.distance_metric = distance_metric
        self.margin = margin
        self.model = model
        self.size_average = size_average

    def get_config_dict(self):
        """Return this loss's hyperparameters as a plain dict.

        Predefined metrics are reported as ``"SiameseDistanceMetric.<NAME>"``.
        Any other callable is reported by its ``__name__``, or by its type name
        if it has no ``__name__`` (e.g. ``functools.partial`` objects).
        """
        # Fallback name for custom metrics. getattr avoids an AttributeError for
        # callables that lack __name__.
        distance_metric_name = getattr(self.distance_metric, "__name__", type(self.distance_metric).__name__)
        # The predefined metrics are lambdas (__name__ == "<lambda>"), so look
        # up the attribute name they are bound to on SiameseDistanceMetric.
        for name, value in vars(SiameseDistanceMetric).items():
            if value is self.distance_metric:
                distance_metric_name = f"SiameseDistanceMetric.{name}"
                break

        return {"distance_metric": distance_metric_name, "margin": self.margin, "size_average": self.size_average}

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
        """Compute the contrastive loss for a batch of text pairs.

        Args:
            sentence_features: Exactly two feature dicts. The first holds the
                anchor texts and the second the paired texts. Each is passed
                through ``self.model``, and its ``"sentence_embedding"`` output
                is used.
            labels: Pair labels, 1 for positive (similar) and 0 for negative
                (dissimilar) pairs. They are converted to float before use.
                Presumably of shape (batch_size,), matching the distances.
                TODO confirm.

        Returns:
            A scalar tensor. It is the mean of the per-pair losses when
            ``size_average`` is True, otherwise their sum.

        Raises:
            ValueError: If ``sentence_features`` does not yield exactly two inputs.
        """
        reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(reps) != 2:
            raise ValueError(f"ContrastiveLoss expects exactly 2 inputs (anchor, other), but got {len(reps)}.")
        rep_anchor, rep_other = reps
        distances = self.distance_metric(rep_anchor, rep_other)
        # Convert once. This also makes bool labels work: `1 - bool_tensor` is an error in torch.
        labels = labels.float()
        # Positives are pulled together (d^2). Negatives are only penalized
        # while they are closer than the margin (relu(margin - d)^2).
        losses = 0.5 * (labels * distances.pow(2) + (1 - labels) * F.relu(self.margin - distances).pow(2))
        return losses.mean() if self.size_average else losses.sum()

    @property
    def citation(self) -> str:
        """BibTeX entry for the paper (Hadsell, Chopra & LeCun, 2006) this loss is based on."""
        return """
@inproceedings{hadsell2006dimensionality,
    author={Hadsell, R. and Chopra, S. and LeCun, Y.},
    booktitle={2006 IEEE Computer Society Conference on Computer Vision and Pattern Recognition (CVPR'06)}, 
    title={Dimensionality Reduction by Learning an Invariant Mapping}, 
    year={2006},
    volume={2},
    number={},
    pages={1735-1742},
    doi={10.1109/CVPR.2006.100}
}
"""