OnlineContrastiveLoss.py 3.78 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
from typing import Dict, Iterable

Rayyyyy's avatar
Rayyyyy committed
3
import torch.nn.functional as F
Rayyyyy's avatar
Rayyyyy committed
4
5
from torch import Tensor, nn

Rayyyyy's avatar
Rayyyyy committed
6
7
from sentence_transformers.SentenceTransformer import SentenceTransformer

Rayyyyy's avatar
Rayyyyy committed
8
9
from .ContrastiveLoss import SiameseDistanceMetric

Rayyyyy's avatar
Rayyyyy committed
10
11
12
13
14
15
16
17
18
19

class OnlineContrastiveLoss(nn.Module):
    def __init__(
        self, model: SentenceTransformer, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5
    ):
        """
        This Online Contrastive loss is similar to :class:`ContrastiveLoss`, but it selects hard positive (positives that
        are far apart) and hard negative pairs (negatives that are close) and computes the loss only for these pairs.
        This loss often yields better performances than ContrastiveLoss.

        Args:
            model: SentenceTransformer model
            distance_metric: Function that returns a distance between
                two embeddings. The class SiameseDistanceMetric contains
                pre-defined metrics that can be used
            margin: Negative samples (label == 0) should have a distance
                of at least the margin value.

        References:
            - `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_

        Requirements:
            1. (anchor, positive/negative) pairs
            2. Data should include hard positives and hard negatives

        Relations:
            - :class:`ContrastiveLoss` is similar, but does not use hard positive and hard negative pairs.
              :class:`OnlineContrastiveLoss` often yields better results.

        Inputs:
            +-----------------------------------------------+------------------------------+
            | Texts                                         | Labels                       |
            +===============================================+==============================+
            | (anchor, positive/negative) pairs             | 1 if positive, 0 if negative |
            +-----------------------------------------------+------------------------------+

        Example:
            ::

                from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, losses
                from datasets import Dataset

                model = SentenceTransformer("microsoft/mpnet-base")
                train_dataset = Dataset.from_dict({
                    "sentence1": ["It's nice weather outside today.", "He drove to work."],
                    "sentence2": ["It's so sunny.", "She walked to the store."],
                    "label": [1, 0],
                })
                loss = losses.OnlineContrastiveLoss(model)

                trainer = SentenceTransformerTrainer(
                    model=model,
                    train_dataset=train_dataset,
                    loss=loss,
                )
                trainer.train()
        """
        super().__init__()
        self.model = model
        self.margin = margin
        self.distance_metric = distance_metric

    def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor, size_average=False) -> Tensor:
        """
        Compute the online contrastive loss for one batch of sentence pairs.

        Args:
            sentence_features: Iterable of feature dicts; each one is passed to
                ``self.model`` and its ``"sentence_embedding"`` output is used.
                Only the first two resulting embeddings are compared.
            labels: Tensor compared against 1 (positive pair) and 0 (negative
                pair) to split the distances.
            size_average: Accepted for call compatibility; it is not read by
                this method. The loss is always a sum, never an average.

        Returns:
            The sum of squared distances of the selected hard positive pairs
            plus the sum of squared ``relu(margin - distance)`` of the selected
            hard negative pairs.
        """
        embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]

        # presumably one distance per (anchor, other) pair, aligned with `labels` -- TODO confirm
        distances = self.distance_metric(embeddings[0], embeddings[1])
        negs = distances[labels == 0]
        poss = distances[labels == 1]

        # Select hard positive and hard negative pairs.
        # A negative is "hard" when it is closer than the farthest positive; a positive is
        # "hard" when it is farther than the closest negative. When the opposite class has at
        # most one sample in the batch, the class's own mean distance is the threshold instead.
        # NOTE(review): `> 1` means a single opposite-class sample is ignored; `> 0` may have
        # been intended. Behaviour is kept as-is here -- confirm before changing.
        negative_threshold = poss.max() if len(poss) > 1 else negs.mean()
        positive_threshold = negs.min() if len(negs) > 1 else poss.mean()
        negative_pairs = negs[negs < negative_threshold]
        positive_pairs = poss[poss > positive_threshold]

        # Positives are pulled together (squared distance); negatives are only penalised
        # while they are still closer than `margin`.
        positive_loss = positive_pairs.pow(2).sum()
        negative_loss = F.relu(self.margin - negative_pairs).pow(2).sum()
        loss = positive_loss + negative_loss
        return loss