# Loss Overview

Loss functions play a critical role in the performance of your fine-tuned model. Sadly, there is no "one size fits all" loss function. Ideally, this overview should help narrow down your choice of loss function(s) by matching them to your data formats.

**Note**: you can often convert one training data format into another, allowing more loss functions to be viable for your scenario. For example, `(sentence_A, sentence_B) pairs` with `class` labels can be converted into `(anchor, positive, negative) triplets` by sampling sentences with the same or different classes.
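
A minimal sketch of that conversion in plain Python, assuming the labeled data can be flattened into `(sentence, class)` items. The function name `make_triplets` and the random sampling strategy are illustrative assumptions, not a Sentence Transformers API:

```python
import random
from collections import defaultdict


def make_triplets(labeled_sentences, seed=0):
    """Build (anchor, positive, negative) triplets from (sentence, class) items.

    Positives share the anchor's class; negatives come from any other class.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for sentence, label in labeled_sentences:
        by_class[label].append(sentence)

    triplets = []
    for sentence, label in labeled_sentences:
        positives = [s for s in by_class[label] if s != sentence]
        negatives = [s for other, sents in by_class.items() if other != label for s in sents]
        if not positives or not negatives:
            continue  # no valid positive or negative exists for this anchor
        triplets.append((sentence, rng.choice(positives), rng.choice(negatives)))
    return triplets


data = [
    ("The cat sits on the mat.", "animals"),
    ("A dog is playing fetch.", "animals"),
    ("Stocks fell sharply today.", "finance"),
    ("The central bank raised rates.", "finance"),
]
for triplet in make_triplets(data):
    print(triplet)
```

Anchors whose class has no other member are skipped, since no positive can be sampled for them.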

| Texts                                         | Labels                         | Appropriate Loss Functions                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                       |
|-----------------------------------------------|--------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `single sentences`                            | `class`                        | <a href="../package_reference/losses.html#batchalltripletloss">`BatchAllTripletLoss`</a><br><a href="../package_reference/losses.html#batchhardsoftmargintripletloss">`BatchHardSoftMarginTripletLoss`</a><br><a href="../package_reference/losses.html#batchhardtripletloss">`BatchHardTripletLoss`</a><br><a href="../package_reference/losses.html#batchsemihardtripletloss">`BatchSemiHardTripletLoss`</a>                                                                                                                                   |
| `single sentences`                            | `none`                         | <a href="../package_reference/losses.html#contrastivetensionloss">`ContrastiveTensionLoss`</a><br><a href="../package_reference/losses.html#denoisingautoencoderloss">`DenoisingAutoEncoderLoss`</a>                                                                                                                                                                                                                                                                                                                                             |
| `(anchor, anchor) pairs`                      | `none`                         | <a href="../package_reference/losses.html#contrastivetensionlossinbatchnegatives">`ContrastiveTensionLossInBatchNegatives`</a>                                                                                                                                                                                                                                                                                                                                                                                                                   |
| `(damaged_sentence, original_sentence) pairs` | `none`                         | <a href="../package_reference/losses.html#denoisingautoencoderloss">`DenoisingAutoEncoderLoss`</a>                                                                                                                                                                                                                                                                                                                                                                                                                                               |
| `(sentence_A, sentence_B) pairs`              | `class`                        | <a href="../package_reference/losses.html#softmaxloss">`SoftmaxLoss`</a>                                                                                                                                                                                                                                                                                                                                                                                                                                                                         |
| `(anchor, positive) pairs`                    | `none`                         | <a href="../package_reference/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#multiplenegativessymmetricrankingloss">`MultipleNegativesSymmetricRankingLoss`</a><br><a href="../package_reference/losses.html#megabatchmarginloss">`MegaBatchMarginLoss`</a><br><a href="../package_reference/losses.html#gistembedloss">`GISTEmbedLoss`</a> |
| `(anchor, positive/negative) pairs`           | `1 if positive, 0 if negative` | <a href="../package_reference/losses.html#contrastiveloss">`ContrastiveLoss`</a><br><a href="../package_reference/losses.html#onlinecontrastiveloss">`OnlineContrastiveLoss`</a>                                                                                                                                                                                                                                                                                                                                                                 |
| `(sentence_A, sentence_B) pairs`              | `float similarity score`       | <a href="../package_reference/losses.html#cosentloss">`CoSENTLoss`</a><br><a href="../package_reference/losses.html#angleloss">`AnglELoss`</a><br><a href="../package_reference/losses.html#cosinesimilarityloss">`CosineSimilarityLoss`</a>                                                                                                                                                                                                                                                                                                     |
| `(anchor, positive, negative) triplets`       | `none`                         | <a href="../package_reference/losses.html#cachedmultiplenegativesrankingloss">`CachedMultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#multiplenegativesrankingloss">`MultipleNegativesRankingLoss`</a><br><a href="../package_reference/losses.html#tripletloss">`TripletLoss`</a><br><a href="../package_reference/losses.html#gistembedloss">`GISTEmbedLoss`</a>                                                                                                                                                 |

## Loss modifiers

These loss functions can be seen as *loss modifiers*: they work on top of standard loss functions, but apply those loss functions in different ways to try to instil useful properties into the trained embedding model.

For example, models trained with <a href="../package_reference/losses.html#matryoshkaloss">`MatryoshkaLoss`</a> produce embeddings whose size can be truncated without notable losses in performance, and models trained with <a href="../package_reference/losses.html#adaptivelayerloss">`AdaptiveLayerLoss`</a> still perform well when you remove model layers for faster inference.

| Texts | Labels | Appropriate Loss Functions                                                                                                                                                                                                                                   |
|-------|--------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| `any` | `any`  | <a href="../package_reference/losses.html#matryoshkaloss">`MatryoshkaLoss`</a><br><a href="../package_reference/losses.html#adaptivelayerloss">`AdaptiveLayerLoss`</a><br><a href="../package_reference/losses.html#matryoshka2dloss">`Matryoshka2dLoss`</a> |
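
Conceptually, a modifier such as `MatryoshkaLoss` reuses a base loss on several truncated versions of the same embeddings and combines the results. The plain-Python sketch below illustrates that idea only; the helper names, the toy base loss, and the choice to renormalize after truncation are assumptions for illustration rather than the library's implementation:

```python
import math


def truncate_and_normalize(embedding, dim):
    """Keep the first `dim` values of an embedding and rescale them to unit length."""
    head = embedding[:dim]
    norm = math.sqrt(sum(x * x for x in head))
    return [x / norm for x in head]


def cosine(u, v):
    """Cosine similarity of two vectors that are already unit length."""
    return sum(a * b for a, b in zip(u, v))


def matryoshka_style_loss(base_loss, anchor, positive, dims, weights):
    """Apply `base_loss` at several truncated sizes and return the weighted sum."""
    total = 0.0
    for dim, weight in zip(dims, weights):
        a = truncate_and_normalize(anchor, dim)
        p = truncate_and_normalize(positive, dim)
        total += weight * base_loss(a, p)
    return total


def toy_pair_loss(a, p):
    """Hypothetical base loss: push the cosine similarity of a positive pair towards 1."""
    return 1.0 - cosine(a, p)


anchor = [0.9, 0.1, 0.3, -0.2]
positive = [0.8, 0.2, -0.1, 0.4]
print(matryoshka_style_loss(toy_pair_loss, anchor, positive, dims=[4, 2], weights=[1.0, 1.0]))
```

Because the loss is also computed on the first two dimensions alone, training pressure is put on the leading dimensions to carry useful information by themselves, which is what makes later truncation viable.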


## Distillation
These loss functions are specifically designed for distilling the knowledge of one model into another.
For example, they can be used to fine-tune a small model to behave more like a larger, stronger one, or to make a model multilingual.

| Texts                                        | Labels                                                        | Appropriate Loss Functions                                                   |
|----------------------------------------------|---------------------------------------------------------------|------------------------------------------------------------------------------|
| `single sentences`                           | `model sentence embeddings`                                   | <a href="../package_reference/losses.html#mseloss">`MSELoss`</a>             |
| `(query, passage_one, passage_two) triplets` | `gold_sim(query, passage_one) - gold_sim(query, passage_two)` | <a href="../package_reference/losses.html#marginmseloss">`MarginMSELoss`</a> |
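
The two label formats in this table can be illustrated with a small sketch. It assumes a teacher model has already produced the target embedding (for `MSELoss`) and the two `gold_sim` scores (for `MarginMSELoss`), and it uses the dot product as the student's similarity function; all function names and numbers are hypothetical:

```python
def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def mse(student_vec, teacher_vec):
    """MSELoss-style objective: mean squared error between student and teacher embeddings."""
    return sum((s - t) ** 2 for s, t in zip(student_vec, teacher_vec)) / len(teacher_vec)


def margin_mse(student_q, student_p1, student_p2, gold_margin):
    """MarginMSELoss-style objective for one (query, passage_one, passage_two) triplet."""
    student_margin = dot(student_q, student_p1) - dot(student_q, student_p2)
    return (student_margin - gold_margin) ** 2


# Label for MSELoss: the teacher's embedding of the same sentence.
teacher_embedding = [0.2, 0.4, 0.4]
student_embedding = [0.1, 0.5, 0.4]
print(mse(student_embedding, teacher_embedding))

# Label for MarginMSELoss: gold_sim(query, passage_one) - gold_sim(query, passage_two),
# using hypothetical teacher scores of 8.5 and 2.0.
gold_margin = 8.5 - 2.0
print(margin_mse([1.0, 0.0], [3.0, 1.0], [1.0, 2.0], gold_margin))  # → 20.25
```

Because `MarginMSELoss` only matches the *difference* between two scores, the student is not forced to reproduce the teacher's absolute score scale.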

## Commonly used Loss Functions
In practice, not all loss functions get used equally often. The most common scenarios are:

* `(anchor, positive) pairs` without any labels: <a href="../package_reference/losses.html#multiplenegativesrankingloss"><code>MultipleNegativesRankingLoss</code></a> is commonly used to train the top-performing embedding models. This kind of data is often relatively cheap to obtain, and the resulting models generally perform very well. <a href="../package_reference/losses.html#cachedmultiplenegativesrankingloss"><code>CachedMultipleNegativesRankingLoss</code></a> is often used to train with larger batch sizes than would otherwise fit in memory; larger batches provide more in-batch negatives, which typically improves performance.
* `(sentence_A, sentence_B) pairs` with a `float similarity score`: <a href="../package_reference/losses.html#cosinesimilarityloss"><code>CosineSimilarityLoss</code></a> has traditionally been used widely, though more recently <a href="../package_reference/losses.html#cosentloss"><code>CoSENTLoss</code></a> and <a href="../package_reference/losses.html#angleloss"><code>AnglELoss</code></a> are used as drop-in replacements that often perform better.
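
To see why batch size matters for `MultipleNegativesRankingLoss`, here is a plain-Python sketch of the in-batch negatives idea it is built on: each anchor is scored against every positive in the batch, and a cross-entropy loss rewards it for ranking its own positive first, so a batch of size *n* supplies *n - 1* negatives per anchor. The cosine similarity and the `scale=20.0` factor are assumptions chosen for illustration; this is not the library's code:

```python
import math


def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


def in_batch_negatives_loss(anchors, positives, scale=20.0):
    """Mean cross-entropy where anchor i should score highest with positive i.

    Every other positive in the batch acts as a negative for anchor i.
    """
    anchors = [normalize(a) for a in anchors]
    positives = [normalize(p) for p in positives]
    total = 0.0
    for i, a in enumerate(anchors):
        # Scaled cosine similarity of anchor i to every positive in the batch
        logits = [scale * sum(x * y for x, y in zip(a, p)) for p in positives]
        # Log-sum-exp with the max subtracted for numerical stability
        m = max(logits)
        log_z = m + math.log(sum(math.exp(logit - m) for logit in logits))
        total += log_z - logits[i]  # negative log-probability of the correct positive
    return total / len(anchors)


anchors = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
positives = [[0.9, 0.1, 0.0], [0.1, 0.9, 0.0], [0.0, 0.1, 0.9]]
print(in_batch_negatives_loss(anchors, positives))                          # aligned pairs: near 0
print(in_batch_negatives_loss(anchors, positives[1:] + positives[:1]))      # shuffled pairs: large
```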
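
The difference between these two kinds of losses can be sketched for a batch of pairs with gold similarity scores. `CosineSimilarityLoss` regresses each predicted cosine similarity onto its gold score (mean squared error here), while the CoSENT objective only penalizes pairs of examples whose predicted ordering disagrees with the gold ordering. The functions below are plain-Python illustrations that assume a scale factor of 20; `AnglELoss` is not sketched:

```python
import math


def cosine_similarity_loss(cos_scores, gold_scores):
    """Regression objective: mean squared error between predicted and gold scores."""
    return sum((c - g) ** 2 for c, g in zip(cos_scores, gold_scores)) / len(gold_scores)


def cosent_loss(cos_scores, gold_scores, scale=20.0):
    """Ranking objective: log(1 + sum over gold_i > gold_j of exp(scale * (cos_j - cos_i)))."""
    terms = [0.0]  # exp(0.0) == 1 supplies the "1 +" inside the log
    n = len(gold_scores)
    for i in range(n):
        for j in range(n):
            if gold_scores[i] > gold_scores[j]:
                terms.append(scale * (cos_scores[j] - cos_scores[i]))
    # Log-sum-exp with the max subtracted for numerical stability
    m = max(terms)
    return m + math.log(sum(math.exp(t - m) for t in terms))


gold = [0.8, 0.4, 0.0]
print(cosent_loss([0.9, 0.5, 0.1], gold))  # predicted ordering matches gold: small loss
print(cosent_loss([0.1, 0.5, 0.9], gold))  # ordering reversed: large loss
print(cosine_similarity_loss([0.9, 0.5, 0.1], gold))
```

Since CoSENT depends only on score differences between pairs, it does not require the model's cosine similarities to match the gold scores' absolute scale.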