from enum import Enum
class SimilarityFunction(Enum):
COSINE = 0
EUCLIDEAN = 1
MANHATTAN = 2
DOT_PRODUCT = 3
from . import SentenceEvaluator
import logging
from ..util import pytorch_cos_sim
import os
import csv
import numpy as np
from typing import List
import torch
logger = logging.getLogger(__name__)
class TranslationEvaluator(SentenceEvaluator):
"""
Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3...) and (fr_1, fr_2, fr_3, ...),
and assuming that fr_i is the translation of en_i.
Checks if vec(en_i) has the highest similarity to vec(fr_i). Computes the accuracy in both directions
"""
def __init__(
self,
source_sentences: List[str],
target_sentences: List[str],
show_progress_bar: bool = False,
batch_size: int = 16,
name: str = "",
print_wrong_matches: bool = False,
write_csv: bool = True,
):
"""
Constructs an evaluator for the given parallel sentence dataset, where target_sentences[i] is the translation of source_sentences[i].
:param source_sentences:
List of sentences in source language
:param target_sentences:
List of sentences in target language
:param print_wrong_matches:
Prints incorrect matches
:param write_csv:
Write results to CSV file
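Example:
A minimal usage sketch; the sentence lists are placeholders and the model name is only an illustration.
::
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TranslationEvaluator
model = SentenceTransformer('distiluse-base-multilingual-cased-v2')
src = ['The cat sits outside', 'A man is playing guitar']
trg = ['Die Katze sitzt draussen', 'Ein Mann spielt Gitarre']
evaluator = TranslationEvaluator(src, trg, name='en-de')
score = evaluator(model, output_path='.')  # mean of src2trg and trg2src accuracy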
"""
self.source_sentences = source_sentences
self.target_sentences = target_sentences
self.name = name
self.batch_size = batch_size
self.show_progress_bar = show_progress_bar
self.print_wrong_matches = print_wrong_matches
assert len(self.source_sentences) == len(self.target_sentences)
if name:
name = "_" + name
self.csv_file = "translation_evaluation" + name + "_results.csv"
self.csv_headers = ["epoch", "steps", "src2trg", "trg2src"]
self.write_csv = write_csv
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = " after epoch {}:".format(epoch)
else:
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
else:
out_txt = ":"
logger.info("Evaluating translation matching Accuracy on " + self.name + " dataset" + out_txt)
embeddings1 = torch.stack(
model.encode(
self.source_sentences,
show_progress_bar=self.show_progress_bar,
batch_size=self.batch_size,
convert_to_numpy=False,
)
)
embeddings2 = torch.stack(
model.encode(
self.target_sentences,
show_progress_bar=self.show_progress_bar,
batch_size=self.batch_size,
convert_to_numpy=False,
)
)
cos_sims = pytorch_cos_sim(embeddings1, embeddings2).detach().cpu().numpy()
correct_src2trg = 0
correct_trg2src = 0
for i in range(len(cos_sims)):
max_idx = np.argmax(cos_sims[i])
if i == max_idx:
correct_src2trg += 1
elif self.print_wrong_matches:
print("i:", i, "j:", max_idx, "INCORRECT" if i != max_idx else "CORRECT")
print("Src:", self.source_sentences[i])
print("Trg:", self.target_sentences[max_idx])
print("Argmax score:", cos_sims[i][max_idx], "vs. correct score:", cos_sims[i][i])
results = zip(range(len(cos_sims[i])), cos_sims[i])
results = sorted(results, key=lambda x: x[1], reverse=True)
for idx, score in results[0:5]:
print("\t", idx, "(Score: %.4f)" % (score), self.target_sentences[idx])
cos_sims = cos_sims.T
for i in range(len(cos_sims)):
max_idx = np.argmax(cos_sims[i])
if i == max_idx:
correct_trg2src += 1
acc_src2trg = correct_src2trg / len(cos_sims)
acc_trg2src = correct_trg2src / len(cos_sims)
logger.info("Accuracy src2trg: {:.2f}".format(acc_src2trg * 100))
logger.info("Accuracy trg2src: {:.2f}".format(acc_trg2src * 100))
if output_path is not None and self.write_csv:
csv_path = os.path.join(output_path, self.csv_file)
output_file_exists = os.path.isfile(csv_path)
with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
writer = csv.writer(f)
if not output_file_exists:
writer.writerow(self.csv_headers)
writer.writerow([epoch, steps, acc_src2trg, acc_trg2src])
return (acc_src2trg + acc_trg2src) / 2
from . import SentenceEvaluator, SimilarityFunction
import logging
import os
import csv
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances
from typing import List
from ..readers import InputExample
logger = logging.getLogger(__name__)
class TripletEvaluator(SentenceEvaluator):
"""
Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
"""
def __init__(
self,
anchors: List[str],
positives: List[str],
negatives: List[str],
main_distance_function: SimilarityFunction = None,
name: str = "",
batch_size: int = 16,
show_progress_bar: bool = False,
write_csv: bool = True,
):
"""
:param anchors: Sentences to check similarity to. (e.g. a query)
:param positives: List of positive sentences
:param negatives: List of negative sentences
:param main_distance_function: The distance function whose accuracy is returned: ``SimilarityFunction.COSINE``, ``SimilarityFunction.EUCLIDEAN`` or ``SimilarityFunction.MANHATTAN``. Defaults to None, in which case the maximum of the three accuracies is returned.
:param name: Name for the output
:param batch_size: Batch size used to compute embeddings
:param show_progress_bar: If true, prints a progress bar
:param write_csv: Write results to a CSV file
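Example:
A minimal usage sketch; the sentences are placeholders and the model name is only an illustration.
::
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator, SimilarityFunction
model = SentenceTransformer('all-MiniLM-L6-v2')
anchors = ['A man is eating food.', 'A child is playing outside.']
positives = ['A man is eating a piece of bread.', 'A kid plays in the garden.']
negatives = ['A man is riding a horse.', 'The stock market fell today.']
evaluator = TripletEvaluator(anchors, positives, negatives, main_distance_function=SimilarityFunction.COSINE, name='dev')
accuracy = evaluator(model, output_path='.')
# Alternatively, build the evaluator from InputExample triplets:
# evaluator = TripletEvaluator.from_input_examples(dev_examples, name='dev')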
"""
self.anchors = anchors
self.positives = positives
self.negatives = negatives
self.name = name
assert len(self.anchors) == len(self.positives)
assert len(self.anchors) == len(self.negatives)
self.main_distance_function = main_distance_function
self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
logger.getEffectiveLevel() == logging.INFO or logger.getEffectiveLevel() == logging.DEBUG
)
self.show_progress_bar = show_progress_bar
self.csv_file: str = "triplet_evaluation" + ("_" + name if name else "") + "_results.csv"
self.csv_headers = ["epoch", "steps", "accuracy_cosinus", "accuracy_manhattan", "accuracy_euclidean"]
self.write_csv = write_csv
@classmethod
def from_input_examples(cls, examples: List[InputExample], **kwargs):
anchors = []
positives = []
negatives = []
for example in examples:
anchors.append(example.texts[0])
positives.append(example.texts[1])
negatives.append(example.texts[2])
return cls(anchors, positives, negatives, **kwargs)
def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
if epoch != -1:
if steps == -1:
out_txt = " after epoch {}:".format(epoch)
else:
out_txt = " in epoch {} after {} steps:".format(epoch, steps)
else:
out_txt = ":"
logger.info("TripletEvaluator: Evaluating the model on " + self.name + " dataset" + out_txt)
num_triplets = 0
num_correct_cos_triplets, num_correct_manhattan_triplets, num_correct_euclidean_triplets = 0, 0, 0
embeddings_anchors = model.encode(
self.anchors, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True
)
embeddings_positives = model.encode(
self.positives, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True
)
embeddings_negatives = model.encode(
self.negatives, batch_size=self.batch_size, show_progress_bar=self.show_progress_bar, convert_to_numpy=True
)
# Cosine distance
pos_cos_distance = paired_cosine_distances(embeddings_anchors, embeddings_positives)
neg_cos_distances = paired_cosine_distances(embeddings_anchors, embeddings_negatives)
# Manhattan
pos_manhattan_distance = paired_manhattan_distances(embeddings_anchors, embeddings_positives)
neg_manhattan_distances = paired_manhattan_distances(embeddings_anchors, embeddings_negatives)
# Euclidean
pos_euclidean_distance = paired_euclidean_distances(embeddings_anchors, embeddings_positives)
neg_euclidean_distances = paired_euclidean_distances(embeddings_anchors, embeddings_negatives)
for idx in range(len(pos_cos_distance)):
num_triplets += 1
if pos_cos_distance[idx] < neg_cos_distances[idx]:
num_correct_cos_triplets += 1
if pos_manhattan_distance[idx] < neg_manhattan_distances[idx]:
num_correct_manhattan_triplets += 1
if pos_euclidean_distance[idx] < neg_euclidean_distances[idx]:
num_correct_euclidean_triplets += 1
accuracy_cos = num_correct_cos_triplets / num_triplets
accuracy_manhattan = num_correct_manhattan_triplets / num_triplets
accuracy_euclidean = num_correct_euclidean_triplets / num_triplets
logger.info("Accuracy Cosine Distance: \t{:.2f}".format(accuracy_cos * 100))
logger.info("Accuracy Manhattan Distance:\t{:.2f}".format(accuracy_manhattan * 100))
logger.info("Accuracy Euclidean Distance:\t{:.2f}\n".format(accuracy_euclidean * 100))
if output_path is not None and self.write_csv:
csv_path = os.path.join(output_path, self.csv_file)
if not os.path.isfile(csv_path):
with open(csv_path, newline="", mode="w", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow(self.csv_headers)
writer.writerow([epoch, steps, accuracy_cos, accuracy_manhattan, accuracy_euclidean])
else:
with open(csv_path, newline="", mode="a", encoding="utf-8") as f:
writer = csv.writer(f)
writer.writerow([epoch, steps, accuracy_cos, accuracy_manhattan, accuracy_euclidean])
if self.main_distance_function == SimilarityFunction.COSINE:
return accuracy_cos
if self.main_distance_function == SimilarityFunction.MANHATTAN:
return accuracy_manhattan
if self.main_distance_function == SimilarityFunction.EUCLIDEAN:
return accuracy_euclidean
return max(accuracy_cos, accuracy_manhattan, accuracy_euclidean)
from .SentenceEvaluator import SentenceEvaluator
from .SimilarityFunction import SimilarityFunction
from .BinaryClassificationEvaluator import BinaryClassificationEvaluator
from .EmbeddingSimilarityEvaluator import EmbeddingSimilarityEvaluator
from .InformationRetrievalEvaluator import InformationRetrievalEvaluator
from .LabelAccuracyEvaluator import LabelAccuracyEvaluator
from .MSEEvaluator import MSEEvaluator
from .MSEEvaluatorFromDataFrame import MSEEvaluatorFromDataFrame
from .ParaphraseMiningEvaluator import ParaphraseMiningEvaluator
from .SequentialEvaluator import SequentialEvaluator
from .TranslationEvaluator import TranslationEvaluator
from .TripletEvaluator import TripletEvaluator
from .RerankingEvaluator import RerankingEvaluator
__all__ = [
"SentenceEvaluator",
"SimilarityFunction",
"BinaryClassificationEvaluator",
"EmbeddingSimilarityEvaluator",
"InformationRetrievalEvaluator",
"LabelAccuracyEvaluator",
"MSEEvaluator",
"MSEEvaluatorFromDataFrame",
"ParaphraseMiningEvaluator",
"SequentialEvaluator",
"TranslationEvaluator",
"TripletEvaluator",
"RerankingEvaluator",
]
import random
from typing import Any, Dict, Iterable, List, Tuple
import warnings
from torch import Tensor, nn
from torch.nn import functional as F
import torch
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses.CachedMultipleNegativesRankingLoss import CachedMultipleNegativesRankingLoss
from sentence_transformers.models import Transformer
class TransformerDecorator:
"""
Decorator that caches the embeddings of all layers of the transformer.
When `layer_idx` is set, it returns the cached embeddings of that layer instead.
This is meant to override the forward function of the Transformer.
"""
def __init__(self, transformer: Transformer, original_forward):
self.transformer = transformer
self.original_forward = original_forward
self.embeddings: List[Tuple[Tensor]] = []
self.last_embeddings: List[Tensor] = []
self.features: List[Dict[str, Tensor]] = []
self.layer_idx = None
self.call_idx = 0
def set_layer_idx(self, layer_idx):
self.layer_idx = layer_idx
self.call_idx = 0
def get_layer_embeddings(self):
return torch.concat([embedding[self.layer_idx] for embedding in self.embeddings], dim=1)
def __call__(self, features):
if self.layer_idx is None:
output = self.call_grow_cache(features)
else:
output = self.call_use_cache(features)
self.call_idx += 1
return output
def call_grow_cache(self, features):
"""
Temporarily sets the output_hidden_states to True, runs the model, and then restores the original setting.
Use the all_layer_embeddings to get the embeddings of all layers.
"""
original_output_hidden_states = self.transformer.auto_model.config.output_hidden_states
self.transformer.auto_model.config.output_hidden_states = True
output = self.original_forward(features)
# We ignore the first layer, as it is the input embeddings
# and the last layer, as we already computed the loss over it
self.num_layers = len(output["all_layer_embeddings"]) - 1
self.embeddings.append(output["all_layer_embeddings"][1:-1])
self.last_embeddings.append(output["token_embeddings"])
self.features.append(
{key: value for key, value in output.items() if key not in ["all_layer_embeddings", "token_embeddings"]}
)
# Restore original setting
self.transformer.auto_model.config.output_hidden_states = original_output_hidden_states
if original_output_hidden_states:
del output["all_layer_embeddings"]
return output
def call_use_cache(self, features):
return {**self.features[self.call_idx], "token_embeddings": self.embeddings[self.call_idx][self.layer_idx]}
class ForwardDecorator:
"""
Decorator that caches the embeddings after all modules (e.g. pooling) of the model.
Required to get the embeddings after all modules for the KL-divergence loss.
This is meant to override the forward function of the SentenceTransformer.
"""
def __init__(self, fn):
self.fn = fn
self.embeddings = []
def __call__(self, features):
output = self.fn(features)
self.embeddings.append(output["sentence_embedding"])
return output
def get_embeddings(self):
embeddings = torch.concat(self.embeddings, dim=0)
self.embeddings = []
return embeddings
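# Illustrative only (not part of the library): a toy check of the caching pattern above.
# ForwardDecorator wraps any forward-like callable, records the "sentence_embedding"
# of every call, and hands back the concatenation once via get_embeddings().
def _forward_decorator_demo():
    import torch
    dummy_forward = lambda features: {"sentence_embedding": features["x"]}
    decorated = ForwardDecorator(dummy_forward)
    decorated({"x": torch.ones(2, 4)})
    decorated({"x": torch.zeros(3, 4)})
    # Two calls with 2 and 3 "sentences" yield a concatenated (5, 4) tensor
    assert decorated.get_embeddings().shape == (5, 4)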
class AdaptiveLayerLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
loss: nn.Module,
n_layers_per_step: int = 1,
last_layer_weight: float = 1.0,
prior_layers_weight: float = 1.0,
kl_div_weight: float = 1.0,
kl_temperature: float = 0.3,
) -> None:
"""
The AdaptiveLayerLoss can be seen as a loss *modifier* that allows you to use other loss functions at non-final
layers of the Sentence Transformer model. This is useful when you want to train a model where users have the
option to lower the number of layers used in order to improve inference speed and memory usage.
:param model: SentenceTransformer model
:param loss: The loss function to be used, e.g. :class:`MultipleNegativesRankingLoss`, :class:`CoSENTLoss`, etc.
:param n_layers_per_step: The number of layers to use per step. If -1, then all layers are used. If > 0, then
a random sample of `n_layers_per_step` layers are used per step, separate from the final layer, which is
always used. The 2DMSE paper uses `n_layers_per_step=1`. The default value is 1.
:param last_layer_weight: The weight to use for the loss of the final layer. Increase this to focus more on the
performance when using all layers. The default value is 1.0.
:param prior_layers_weight: The weight to use for the loss of the prior layers. Increase this to focus more on
the performance when using fewer layers. The default value is 1.0.
:param kl_div_weight: The weight to use for the KL-divergence loss that is used to make the prior layers match
that of the last layer. Increase this to focus more on the performance when using fewer layers. The default
value is 1.0.
:param kl_temperature: The temperature to use for the KL-divergence loss. If 0, then the KL-divergence loss is
not used. The default value is 0.3.
References:
- The concept was inspired by the 2DMSE paper: https://arxiv.org/abs/2402.14776
- `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_
Requirements:
1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`.
Relations:
- :class:`Matryoshka2dLoss` uses this loss in combination with :class:`MatryoshkaLoss` which allows for
output dimensionality reduction for faster downstream tasks (e.g. retrieval).
Input:
+---------------------------------------+--------+
| Texts | Labels |
+=======================================+========+
| any | any |
+---------------------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('microsoft/mpnet-base')
train_examples = [
InputExample(texts=['Anchor 1', 'Positive 1']),
InputExample(texts=['Anchor 2', 'Positive 2']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model=model)
train_loss = losses.AdaptiveLayerLoss(model, train_loss)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super().__init__()
self.model = model
self.loss = loss
self.n_layers_per_step = n_layers_per_step
self.last_layer_weight = last_layer_weight
self.prior_layers_weight = prior_layers_weight
self.kl_div_weight = kl_div_weight
self.kl_temperature = kl_temperature
assert isinstance(self.model[0], Transformer)
if isinstance(loss, CachedMultipleNegativesRankingLoss):
warnings.warn("MatryoshkaLoss is not compatible with CachedMultipleNegativesRankingLoss.", stacklevel=2)
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
# Decorate the forward function of the transformer to cache the embeddings of all layers
original_transformer_forward = self.model[0].forward
transformer_decorator = TransformerDecorator(self.model[0], original_transformer_forward)
self.model[0].forward = transformer_decorator
# Decorate the forward function of the model to get the embeddings after all modules (e.g. pooling)
original_forward = self.model.forward
forward_decorator = ForwardDecorator(original_forward)
self.model.forward = forward_decorator
# Run the loss normally: i.e. the final layer, but 1) use the transformers decorator to cache
# the embeddings of all layers and 2) use the forward decorator to get the embeddings after all modules
# for the KL-divergence loss
loss = self.loss(sentence_features, labels) * self.last_layer_weight
if self.kl_temperature > 0:
final_embeddings = forward_decorator.get_embeddings()
final_embeddings = F.softmax(final_embeddings / self.kl_temperature, dim=-1)
num_layers = transformer_decorator.num_layers
layer_indices = range(num_layers - 1)
if self.n_layers_per_step > 0 and self.n_layers_per_step < num_layers - 1:
layer_indices = random.sample(layer_indices, self.n_layers_per_step)
# This loop is over `num_layers - 1` layers because we already computed the loss over the final layer
for layer_idx in layer_indices:
# Add regular loss for each layer by using the cached embeddings of that layer
transformer_decorator.set_layer_idx(layer_idx)
layer_loss = self.loss(sentence_features, labels)
loss = loss + layer_loss / (1 + layer_idx) / len(layer_indices) * self.prior_layers_weight
# and KL-divergence loss between the current layer and the final layer
# Note: we use "batchmean" reduction as that aligns with the mathematical definition
if self.kl_temperature > 0:
embeddings = forward_decorator.get_embeddings()
kl_div_loss = F.kl_div(
F.log_softmax(embeddings / self.kl_temperature, dim=-1),
final_embeddings,
reduction="batchmean",
)
loss = loss + kl_div_loss * self.kl_temperature * self.kl_div_weight
self.model[0].forward = original_transformer_forward
self.model.forward = original_forward
return loss
def get_config_dict(self) -> Dict[str, Any]:
return {
"loss": self.loss.__class__.__name__,
"n_layers_per_step": self.n_layers_per_step,
"last_layer_weight": self.last_layer_weight,
"prior_layers_weight": self.prior_layers_weight,
"kl_div_weight": self.kl_div_weight,
"kl_temperature": self.kl_temperature,
}
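# Illustrative only (not part of the library): after training with AdaptiveLayerLoss,
# one might truncate the transformer at inference time to trade accuracy for speed.
# This sketch assumes a BERT/MPNet-style auto_model whose layers live in `encoder.layer`;
# 'path/to/adaptive-layer-model' is a placeholder path.
def _truncate_layers_demo(model_name_or_path: str = "path/to/adaptive-layer-model", keep_layers: int = 4):
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(model_name_or_path)
    # Keep only the first `keep_layers` transformer layers of the underlying model
    model[0].auto_model.encoder.layer = model[0].auto_model.encoder.layer[:keep_layers]
    return model.encode(["An example sentence"])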
from sentence_transformers import losses, SentenceTransformer, util
class AnglELoss(losses.CoSENTLoss):
def __init__(self, model: SentenceTransformer, scale: float = 20.0):
"""
This class implements AnglE (Angle Optimized) loss.
This is a modification of :class:`CoSENTLoss`, designed to address the following issue:
The gradient of the cosine function approaches 0 as the cosine similarity approaches 1 or -1.
This can hinder the optimization process, so AnglE proposes to instead optimize the angle difference
in complex space in order to mitigate this effect.
It expects that each of the InputExamples consists of a pair of texts and a float valued label, representing
the expected similarity score between the pair.
It computes the following loss function:
``loss = logsum(1+exp(s(k,l)-s(i,j))+exp...)``, where ``(i,j)`` and ``(k,l)`` are any of the input pairs in the
batch such that the expected similarity of ``(i,j)`` is greater than ``(k,l)``. The summation is over all possible
pairs of input pairs in the batch that match this condition. This is the same as CoSENTLoss, with a different
similarity function.
:param model: SentenceTransformerModel
:param scale: Output of similarity function is multiplied by scale value. Represents the inverse temperature.
References:
- For further details, see: https://arxiv.org/abs/2309.12871v1
Requirements:
- Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1].
Relations:
- :class:`CoSENTLoss` is AnglELoss with ``pairwise_cos_sim`` as the metric, rather than ``pairwise_angle_sim``.
- :class:`CosineSimilarityLoss` seems to produce a weaker training signal than ``CoSENTLoss`` or ``AnglELoss``.
Inputs:
+--------------------------------+------------------------+
| Texts | Labels |
+================================+========================+
| (sentence_A, sentence_B) pairs | float similarity score |
+--------------------------------+------------------------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('bert-base-uncased')
train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3)]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.AnglELoss(model=model)
"""
super().__init__(model, scale, similarity_fct=util.pairwise_angle_sim)
from torch import nn, Tensor
from typing import Iterable, Dict
from .BatchHardTripletLoss import BatchHardTripletLoss, BatchHardTripletLossDistanceFunction
from sentence_transformers.SentenceTransformer import SentenceTransformer
class BatchAllTripletLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance,
margin: float = 5,
):
"""
BatchAllTripletLoss takes a batch with (sentence, label) pairs and computes the loss for all possible, valid
triplets, i.e., anchor and positive must have the same label, anchor and negative a different label. The labels
must be integers, with same label indicating sentences from the same class. Your train dataset
must contain at least 2 examples per label class.
:param model: SentenceTransformer model
:param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used.
:param margin: Negative samples should be at least margin further apart from the anchor than the positive.
References:
* Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
* Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
* Blog post: https://omoindrot.github.io/triplet-loss
Requirements:
1. Each sentence must be labeled with a class.
2. Your dataset must contain at least 2 examples per label class.
Relations:
* :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
* :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than all possible, valid triplets.
Also, it does not require setting a margin.
* :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than all possible, valid triplets.
Inputs:
+------------------+--------+
| Texts | Labels |
+==================+========+
| single sentences | class |
+------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
InputExample(texts=['Sentence from class 0'], label=0),
InputExample(texts=['Another sentence from class 0'], label=0),
InputExample(texts=['Sentence from class 1'], label=1),
InputExample(texts=['Sentence from class 2'], label=2),
]
train_batch_size = 2
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.BatchAllTripletLoss(model=model)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(BatchAllTripletLoss, self).__init__()
self.sentence_embedder = model
self.triplet_margin = margin
self.distance_metric = distance_metric
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
rep = self.sentence_embedder(sentence_features[0])["sentence_embedding"]
return self.batch_all_triplet_loss(labels, rep)
def batch_all_triplet_loss(self, labels, embeddings):
"""Build the triplet loss over a batch of embeddings.
We generate all the valid triplets and average the loss over the positive ones.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
Returns:
Label_Sentence_Triplet: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = self.distance_metric(embeddings)
anchor_positive_dist = pairwise_dist.unsqueeze(2)
anchor_negative_dist = pairwise_dist.unsqueeze(1)
# Compute a 3D tensor of size (batch_size, batch_size, batch_size)
# triplet_loss[i, j, k] will contain the triplet loss of anchor=i, positive=j, negative=k
# Uses broadcasting where the 1st argument has shape (batch_size, batch_size, 1)
# and the 2nd (batch_size, 1, batch_size)
triplet_loss = anchor_positive_dist - anchor_negative_dist + self.triplet_margin
# Put to zero the invalid triplets
# (where label(a) != label(p) or label(n) == label(a) or a == p)
mask = BatchHardTripletLoss.get_triplet_mask(labels)
triplet_loss = mask.float() * triplet_loss
# Remove negative losses (i.e. the easy triplets)
triplet_loss[triplet_loss < 0] = 0
# Count number of positive triplets (where triplet_loss > 0)
valid_triplets = triplet_loss[triplet_loss > 1e-16]
num_positive_triplets = valid_triplets.size(0)
# num_valid_triplets = mask.sum()
# fraction_positive_triplets = num_positive_triplets / (num_valid_triplets.float() + 1e-16)
# Get final mean triplet loss over the positive valid triplets
triplet_loss = triplet_loss.sum() / (num_positive_triplets + 1e-16)
return triplet_loss
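# Illustrative only (not part of the library): the broadcasting used in
# batch_all_triplet_loss above, shown on a toy 3x3 pairwise-distance matrix.
def _triplet_broadcasting_demo():
    import torch
    pairwise_dist = torch.rand(3, 3)
    anchor_positive = pairwise_dist.unsqueeze(2)  # (3, 3, 1): d(anchor=i, positive=j)
    anchor_negative = pairwise_dist.unsqueeze(1)  # (3, 1, 3): d(anchor=i, negative=k)
    triplet = anchor_positive - anchor_negative + 5.0  # (3, 3, 3): margin loss for every (i, j, k)
    assert triplet.shape == (3, 3, 3)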
import torch
from torch import Tensor
from typing import Iterable, Dict
from .BatchHardTripletLoss import BatchHardTripletLoss, BatchHardTripletLossDistanceFunction
from sentence_transformers.SentenceTransformer import SentenceTransformer
class BatchHardSoftMarginTripletLoss(BatchHardTripletLoss):
def __init__(
self, model: SentenceTransformer, distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance
):
"""
BatchHardSoftMarginTripletLoss takes a batch with (sentence, label) pairs and computes the loss for all possible, valid
triplets, i.e., anchor and positive must have the same label, anchor and negative a different label. The labels
must be integers, with same label indicating sentences from the same class. Your train dataset
must contain at least 2 examples per label class. This soft-margin variant does not require setting a margin.
:param model: SentenceTransformer model
:param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used.
Definitions:
:Easy triplets: Triplets which have a loss of 0 because
``distance(anchor, positive) + margin < distance(anchor, negative)``.
:Hard triplets: Triplets where the negative is closer to the anchor than the positive, i.e.,
``distance(anchor, negative) < distance(anchor, positive)``.
:Semi-hard triplets: Triplets where the negative is not closer to the anchor than the positive, but which
still have a positive loss, i.e., ``distance(anchor, positive) < distance(anchor, negative) + margin``.
References:
* Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
* Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
* Blog post: https://omoindrot.github.io/triplet-loss
Requirements:
1. Each sentence must be labeled with a class.
2. Your dataset must contain at least 2 examples per label class.
3. Your dataset should contain hard positives and negatives.
Relations:
* :class:`BatchHardTripletLoss` uses a user-specified margin, while this loss does not require setting a margin.
Inputs:
+------------------+--------+
| Texts | Labels |
+==================+========+
| single sentences | class |
+------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
InputExample(texts=['Sentence from class 0'], label=0),
InputExample(texts=['Another sentence from class 0'], label=0),
InputExample(texts=['Sentence from class 1'], label=1),
InputExample(texts=['Sentence from class 2'], label=2)
]
train_batch_size = 2
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.BatchHardSoftMarginTripletLoss(model=model)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(BatchHardSoftMarginTripletLoss, self).__init__(model)
self.sentence_embedder = model
self.distance_metric = distance_metric
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
rep = self.sentence_embedder(sentence_features[0])["sentence_embedding"]
return self.batch_hard_triplet_soft_margin_loss(labels, rep)
# Hard Triplet Loss with Soft Margin
# Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
def batch_hard_triplet_soft_margin_loss(self, labels: Tensor, embeddings: Tensor) -> Tensor:
"""Build the triplet loss over a batch of embeddings.
For each anchor, we get the hardest positive and hardest negative to form a triplet.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
Returns:
Label_Sentence_Triplet: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = self.distance_metric(embeddings)
# For each anchor, get the hardest positive
# First, we need to get a mask for every valid positive (they should have same label)
mask_anchor_positive = BatchHardTripletLoss.get_anchor_positive_triplet_mask(labels).float()
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
anchor_positive_dist = mask_anchor_positive * pairwise_dist
# shape (batch_size, 1)
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
# For each anchor, get the hardest negative
# First, we need to get a mask for every valid negative (they should have different labels)
mask_anchor_negative = BatchHardTripletLoss.get_anchor_negative_triplet_mask(labels).float()
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
# shape (batch_size,)
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss with soft margin
# tl = hardest_positive_dist - hardest_negative_dist + margin
# tl[tl < 0] = 0
tl = torch.log1p(torch.exp(hardest_positive_dist - hardest_negative_dist))
triplet_loss = tl.mean()
return triplet_loss
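# Illustrative only (not part of the library): the soft-margin term above,
# log(1 + exp(x)), is the softplus function, a smooth version of max(x, 0)
# that removes the need for a fixed margin.
def _soft_margin_demo():
    import torch
    x = torch.tensor([-2.0, 0.0, 2.0])
    via_log1p = torch.log1p(torch.exp(x))
    via_softplus = torch.nn.functional.softplus(x)
    assert torch.allclose(via_log1p, via_softplus)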
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from sentence_transformers import util
from sentence_transformers.SentenceTransformer import SentenceTransformer
class BatchHardTripletLossDistanceFunction:
"""
This class defines distance functions that can be used with Batch[All/Hard/SemiHard]TripletLoss
"""
@staticmethod
def cosine_distance(embeddings):
"""
Compute the 2D matrix of cosine distances (1-cosine_similarity) between all embeddings.
"""
return 1 - util.pytorch_cos_sim(embeddings, embeddings)
@staticmethod
def eucledian_distance(embeddings, squared=False):
"""
Compute the 2D matrix of euclidean distances between all the embeddings.
Args:
embeddings: tensor of shape (batch_size, embed_dim)
squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
If false, output is the pairwise euclidean distance matrix.
Returns:
pairwise_distances: tensor of shape (batch_size, batch_size)
"""
dot_product = torch.matmul(embeddings, embeddings.t())
# Get squared L2 norm for each embedding. We can just take the diagonal of `dot_product`.
# This also provides more numerical stability (the diagonal of the result will be exactly 0).
# shape (batch_size,)
square_norm = torch.diag(dot_product)
# Compute the pairwise distance matrix as we have:
# ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2
# shape (batch_size, batch_size)
distances = square_norm.unsqueeze(0) - 2.0 * dot_product + square_norm.unsqueeze(1)
# Because of computation errors, some distances might be negative so we put everything >= 0.0
distances[distances < 0] = 0
if not squared:
# Because the gradient of sqrt is infinite when distances == 0.0 (ex: on the diagonal)
# we need to add a small epsilon where distances == 0.0
mask = distances.eq(0).float()
distances = distances + mask * 1e-16
distances = (1.0 - mask) * torch.sqrt(distances)
return distances
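# Illustrative only (not part of the library): a quick sanity check of
# eucledian_distance against torch.cdist on random embeddings.
def _euclidean_distance_demo():
    import torch
    emb = torch.randn(4, 8)
    ours = BatchHardTripletLossDistanceFunction.eucledian_distance(emb)
    reference = torch.cdist(emb, emb)
    assert torch.allclose(ours, reference, atol=1e-4)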
class BatchHardTripletLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance,
margin: float = 5,
):
"""
BatchHardTripletLoss takes a batch with (sentence, label) pairs and computes the loss for all possible, valid
triplets, i.e., anchor and positive must have the same label, anchor and negative a different label. It then looks
for the hardest positive and the hardest negatives.
The labels must be integers, with same label indicating sentences from the same class. Your train dataset
must contain at least 2 examples per label class.
:param model: SentenceTransformer model
:param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
:param margin: Negative samples should be at least margin further apart from the anchor than the positive.
Definitions:
:Easy triplets: Triplets which have a loss of 0 because
``distance(anchor, positive) + margin < distance(anchor, negative)``.
:Hard triplets: Triplets where the negative is closer to the anchor than the positive, i.e.,
``distance(anchor, negative) < distance(anchor, positive)``.
:Semi-hard triplets: Triplets where the negative is not closer to the anchor than the positive, but which
still have a positive loss, i.e., ``distance(anchor, positive) < distance(anchor, negative) + margin``.
References:
* Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
* Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
* Blog post: https://omoindrot.github.io/triplet-loss
Requirements:
1. Each sentence must be labeled with a class.
2. Your dataset must contain at least 2 examples per label class.
3. Your dataset should contain hard positives and negatives.
Inputs:
+------------------+--------+
| Texts | Labels |
+==================+========+
| single sentences | class |
+------------------+--------+
Relations:
* :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only the hardest positive and negative samples.
* :class:`BatchSemiHardTripletLoss` uses only semi-hard, valid triplets, rather than only the hardest positive and negative samples.
* :class:`BatchHardSoftMarginTripletLoss` does not require setting a margin, while this loss does.
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
InputExample(texts=['Sentence from class 0'], label=0),
InputExample(texts=['Another sentence from class 0'], label=0),
InputExample(texts=['Sentence from class 1'], label=1),
InputExample(texts=['Sentence from class 2'], label=2)
]
train_batch_size = 2
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.BatchHardTripletLoss(model=model)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(BatchHardTripletLoss, self).__init__()
self.sentence_embedder = model
self.triplet_margin = margin
self.distance_metric = distance_metric
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
rep = self.sentence_embedder(sentence_features[0])["sentence_embedding"]
return self.batch_hard_triplet_loss(labels, rep)
# Hard Triplet Loss
# Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
# Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
# Blog post: https://omoindrot.github.io/triplet-loss
def batch_hard_triplet_loss(self, labels: Tensor, embeddings: Tensor) -> Tensor:
"""Build the triplet loss over a batch of embeddings.
For each anchor, we get the hardest positive and hardest negative to form a triplet.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
Returns:
Label_Sentence_Triplet: scalar tensor containing the triplet loss
"""
# Get the pairwise distance matrix
pairwise_dist = self.distance_metric(embeddings)
# For each anchor, get the hardest positive
# First, we need to get a mask for every valid positive (they should have same label)
mask_anchor_positive = BatchHardTripletLoss.get_anchor_positive_triplet_mask(labels).float()
# We put to 0 any element where (a, p) is not valid (valid if a != p and label(a) == label(p))
anchor_positive_dist = mask_anchor_positive * pairwise_dist
# shape (batch_size, 1)
hardest_positive_dist, _ = anchor_positive_dist.max(1, keepdim=True)
# For each anchor, get the hardest negative
# First, we need to get a mask for every valid negative (they should have different labels)
mask_anchor_negative = BatchHardTripletLoss.get_anchor_negative_triplet_mask(labels).float()
# We add the maximum value in each row to the invalid negatives (label(a) == label(n))
max_anchor_negative_dist, _ = pairwise_dist.max(1, keepdim=True)
anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)
# shape (batch_size,)
hardest_negative_dist, _ = anchor_negative_dist.min(1, keepdim=True)
# Combine biggest d(a, p) and smallest d(a, n) into final triplet loss
tl = hardest_positive_dist - hardest_negative_dist + self.triplet_margin
tl[tl < 0] = 0
triplet_loss = tl.mean()
return triplet_loss
@staticmethod
def get_triplet_mask(labels):
"""Return a 3D mask where mask[a, p, n] is True iff the triplet (a, p, n) is valid.
A triplet (i, j, k) is valid if:
- i, j, k are distinct
- labels[i] == labels[j] and labels[i] != labels[k]
Args:
labels: integer `Tensor` with shape [batch_size]
"""
# Check that i, j and k are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
i_not_equal_j = indices_not_equal.unsqueeze(2)
i_not_equal_k = indices_not_equal.unsqueeze(1)
j_not_equal_k = indices_not_equal.unsqueeze(0)
distinct_indices = (i_not_equal_j & i_not_equal_k) & j_not_equal_k
label_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
i_equal_j = label_equal.unsqueeze(2)
i_equal_k = label_equal.unsqueeze(1)
valid_labels = ~i_equal_k & i_equal_j
return valid_labels & distinct_indices
@staticmethod
def get_anchor_positive_triplet_mask(labels):
"""Return a 2D mask where mask[a, p] is True iff a and p are distinct and have same label.
Args:
labels: integer `Tensor` with shape [batch_size]
Returns:
mask: boolean `Tensor` with shape [batch_size, batch_size]
"""
# Check that i and j are distinct
indices_equal = torch.eye(labels.size(0), device=labels.device).bool()
indices_not_equal = ~indices_equal
# Check if labels[i] == labels[j]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
labels_equal = labels.unsqueeze(0) == labels.unsqueeze(1)
return labels_equal & indices_not_equal
@staticmethod
def get_anchor_negative_triplet_mask(labels):
"""Return a 2D mask where mask[a, n] is True iff a and n have distinct labels.
Args:
labels: integer `Tensor` with shape [batch_size]
Returns:
mask: boolean `Tensor` with shape [batch_size, batch_size]
"""
# Check if labels[i] != labels[k]
# Uses broadcasting where the 1st argument has shape (1, batch_size) and the 2nd (batch_size, 1)
return ~(labels.unsqueeze(0) == labels.unsqueeze(1))
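# Illustrative only (not part of the library): a toy check of the mask helpers above,
# assuming a batch of 4 sentences with labels [0, 0, 1, 2].
def _triplet_mask_demo():
    import torch
    labels = torch.tensor([0, 0, 1, 2])
    pos_mask = BatchHardTripletLoss.get_anchor_positive_triplet_mask(labels)
    assert pos_mask[0, 1] and not pos_mask[0, 0]  # same label, distinct indices; diagonal is False
    neg_mask = BatchHardTripletLoss.get_anchor_negative_triplet_mask(labels)
    assert neg_mask[0, 2] and not neg_mask[0, 1]  # only pairs with different labels
    triplet_mask = BatchHardTripletLoss.get_triplet_mask(labels)
    assert triplet_mask[0, 1, 2]  # anchor 0, positive 1, negative 2 is a valid triplet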
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from .BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
from sentence_transformers.SentenceTransformer import SentenceTransformer
class BatchSemiHardTripletLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
distance_metric=BatchHardTripletLossDistanceFunction.eucledian_distance,
margin: float = 5,
):
"""
BatchSemiHardTripletLoss takes a batch with (sentence, label) pairs and computes the loss for all possible, valid
triplets, i.e., anchor and positive must have the same label, anchor and negative a different label. It then looks
for the semi-hard positives and negatives.
The labels must be integers, with same label indicating sentences from the same class. Your train dataset
must contain at least 2 examples per label class.
:param model: SentenceTransformer model
:param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
:param margin: Negative samples should be at least margin further apart from the anchor than the positive.
Definitions:
:Easy triplets: Triplets which have a loss of 0 because
``distance(anchor, positive) + margin < distance(anchor, negative)``.
:Hard triplets: Triplets where the negative is closer to the anchor than the positive, i.e.,
``distance(anchor, negative) < distance(anchor, positive)``.
:Semi-hard triplets: Triplets where the negative is not closer to the anchor than the positive, but which
still have a positive loss, i.e., ``distance(anchor, positive) < distance(anchor, negative) + margin``.
References:
* Source: https://github.com/NegatioN/OnlineMiningTripletLoss/blob/master/online_triplet_loss/losses.py
* Paper: In Defense of the Triplet Loss for Person Re-Identification, https://arxiv.org/abs/1703.07737
* Blog post: https://omoindrot.github.io/triplet-loss
Requirements:
1. Each sentence must be labeled with a class.
2. Your dataset must contain at least 2 examples per label class.
3. Your dataset should contain semi-hard positives and negatives.
Relations:
* :class:`BatchHardTripletLoss` uses only the hardest positive and negative samples, rather than semi-hard positives and negatives.
* :class:`BatchAllTripletLoss` uses all possible, valid triplets, rather than only semi-hard positives and negatives.
* :class:`BatchHardSoftMarginTripletLoss` uses only the hardest positive and negative samples, rather than semi-hard positives and negatives.
Also, it does not require setting a margin.
Inputs:
+------------------+--------+
| Texts | Labels |
+==================+========+
| single sentences | class |
+------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
InputExample(texts=['Sentence from class 0'], label=0),
InputExample(texts=['Another sentence from class 0'], label=0),
InputExample(texts=['Sentence from class 1'], label=1),
InputExample(texts=['Sentence from class 2'], label=2)
]
train_batch_size = 2
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.BatchSemiHardTripletLoss(model=model)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(BatchSemiHardTripletLoss, self).__init__()
self.sentence_embedder = model
self.margin = margin
self.distance_metric = distance_metric
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
rep = self.sentence_embedder(sentence_features[0])["sentence_embedding"]
return self.batch_semi_hard_triplet_loss(labels, rep)
# Semi-Hard Triplet Loss
# Based on: https://github.com/tensorflow/addons/blob/master/tensorflow_addons/losses/triplet.py#L71
# Paper: FaceNet: A Unified Embedding for Face Recognition and Clustering: https://arxiv.org/pdf/1503.03832.pdf
def batch_semi_hard_triplet_loss(self, labels: Tensor, embeddings: Tensor) -> Tensor:
"""Build the triplet loss over a batch of embeddings.
We generate all the valid triplets and average the loss over the positive ones.
Args:
labels: labels of the batch, of size (batch_size,)
embeddings: tensor of shape (batch_size, embed_dim)
Returns:
Label_Sentence_Triplet: scalar tensor containing the triplet loss
"""
labels = labels.unsqueeze(1)
pdist_matrix = self.distance_metric(embeddings)
adjacency = labels == labels.t()
adjacency_not = ~adjacency
batch_size = torch.numel(labels)
pdist_matrix_tile = pdist_matrix.repeat([batch_size, 1])
mask = adjacency_not.repeat([batch_size, 1]) & (pdist_matrix_tile > torch.reshape(pdist_matrix.t(), [-1, 1]))
mask_final = torch.reshape(torch.sum(mask, 1, keepdims=True) > 0.0, [batch_size, batch_size])
mask_final = mask_final.t()
negatives_outside = torch.reshape(
BatchSemiHardTripletLoss._masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]
)
negatives_outside = negatives_outside.t()
negatives_inside = BatchSemiHardTripletLoss._masked_maximum(pdist_matrix, adjacency_not)
negatives_inside = negatives_inside.repeat([1, batch_size])
semi_hard_negatives = torch.where(mask_final, negatives_outside, negatives_inside)
loss_mat = (pdist_matrix - semi_hard_negatives) + self.margin
mask_positives = adjacency.float().to(labels.device) - torch.eye(batch_size, device=labels.device)
mask_positives = mask_positives.to(labels.device)
num_positives = torch.sum(mask_positives)
triplet_loss = (
torch.sum(torch.max(loss_mat * mask_positives, torch.tensor([0.0], device=labels.device))) / num_positives
)
return triplet_loss
@staticmethod
def _masked_minimum(data, mask, dim=1):
axis_maximums, _ = data.max(dim, keepdims=True)
masked_minimums = (data - axis_maximums) * mask
masked_minimums, _ = masked_minimums.min(dim, keepdims=True)
masked_minimums += axis_maximums
return masked_minimums
@staticmethod
def _masked_maximum(data, mask, dim=1):
axis_minimums, _ = data.min(dim, keepdims=True)
masked_maximums = (data - axis_minimums) * mask
masked_maximums, _ = masked_maximums.max(dim, keepdims=True)
masked_maximums += axis_minimums
return masked_maximums
from __future__ import annotations
from contextlib import nullcontext
from functools import partial
import torch
from torch import nn, Tensor
from torch.utils.checkpoint import get_device_states, set_device_states
from typing import Iterable, Dict, Iterator, List, Optional, Tuple
from sentence_transformers import SentenceTransformer
from sentence_transformers import util
import tqdm
class RandContext:
"""
Random-state context manager class. Reference: https://github.com/luyug/GradCache.
This class backs up PyTorch's random state during initialization. When the context is entered,
it restores the random state from the backed-up one.
"""
def __init__(self, *tensors):
self.fwd_cpu_state = torch.get_rng_state()
self.fwd_gpu_devices, self.fwd_gpu_states = get_device_states(*tensors)
def __enter__(self):
self._fork = torch.random.fork_rng(devices=self.fwd_gpu_devices, enabled=True)
self._fork.__enter__()
torch.set_rng_state(self.fwd_cpu_state)
set_device_states(self.fwd_gpu_devices, self.fwd_gpu_states)
def __exit__(self, exc_type, exc_val, exc_tb):
self._fork.__exit__(exc_type, exc_val, exc_tb)
self._fork = None
def _backward_hook(
grad_output: Tensor,
sentence_features: Iterable[Dict[str, Tensor]],
loss_obj: CachedMultipleNegativesRankingLoss,
):
"""A backward hook to backpropagate the cached gradients mini-batch by mini-batch."""
assert loss_obj.cache is not None
assert loss_obj.random_states is not None
with torch.enable_grad():
for sentence_feature, grad, random_states in zip(sentence_features, loss_obj.cache, loss_obj.random_states):
for (reps_mb, _), grad_mb in zip(
loss_obj.embed_minibatch_iter(
sentence_feature=sentence_feature,
with_grad=True,
copy_random_state=False,
random_states=random_states,
),
grad,
):
surrogate = torch.dot(reps_mb.flatten(), grad_mb.flatten()) * grad_output
surrogate.backward()
class CachedMultipleNegativesRankingLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
scale: float = 20.0,
similarity_fct: callable[[Tensor, Tensor], Tensor] = util.cos_sim,
mini_batch_size: int = 32,
show_progress_bar: bool = False,
):
"""
Boosted version of MultipleNegativesRankingLoss (https://arxiv.org/pdf/1705.00652.pdf) using GradCache (https://arxiv.org/pdf/2101.06983.pdf).
Contrastive learning (here, our MNRL loss) with in-batch negatives is usually hard to use with large batch sizes due to (GPU) memory limitations.
Batch-scaling methods such as gradient accumulation do not help either, because the in-batch negatives make the data points within
the same batch non-independent, so the batch cannot simply be broken down into mini-batches. GradCache is a smart way to solve this problem.
It divides the computation into two stages, embedding and loss calculation, both of which can be scaled over mini-batches.
As a result, a constant amount of memory (e.g. enough for batch size 32) can now process much larger batches (e.g. 65536).
In detail:
(1) A quick embedding step without gradients/computation graphs to get all the embeddings;
(2) Calculate the loss, backpropagate up to the embeddings, and cache the gradients w.r.t. the embeddings;
(3) A second embedding step with gradients/computation graphs, connecting the cached gradients into the backward chain.
Notes: All steps are done with mini-batches. In the original implementation of GradCache, (2) is not done in mini-batches and
requires a lot of memory when the batch size is large. The main drawback is speed: GradCache sacrifices around 20% computation time according to the paper.
:param model: SentenceTransformer model
:param scale: Output of similarity function is multiplied by scale value
:param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
References:
- Efficient Natural Language Response Suggestion for Smart Reply, Section 4.4: https://arxiv.org/pdf/1705.00652.pdf
- Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup: https://arxiv.org/pdf/2101.06983.pdf
Requirements:
1. (anchor, positive) pairs or (anchor, positive, negative) triplets
2. Should be used with large batch sizes for superior performance, but has slower training time than :class:`MultipleNegativesRankingLoss`
Relations:
- Equivalent to :class:`MultipleNegativesRankingLoss`, but with caching that allows for much higher batch sizes
(and thus better performance) without extra memory usage. This loss also trains roughly 2x to 2.4x slower than
:class:`MultipleNegativesRankingLoss`.
Inputs:
+---------------------------------------+--------+
| Texts | Labels |
+=======================================+========+
| (anchor, positive) pairs | none |
+---------------------------------------+--------+
| (anchor, positive, negative) triplets | none |
+---------------------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-uncased')
train_examples = [
InputExample(texts=['Anchor 1', 'Positive 1']),
InputExample(texts=['Anchor 2', 'Positive 2']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1024) # Here we can try much larger batch sizes!
train_loss = losses.CachedMultipleNegativesRankingLoss(model=model, mini_batch_size = 32)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(CachedMultipleNegativesRankingLoss, self).__init__()
self.model = model
self.scale = scale
self.similarity_fct = similarity_fct
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.mini_batch_size = mini_batch_size
self.cache: Optional[List[List[Tensor]]] = None
self.random_states: Optional[List[List[RandContext]]] = None
self.show_progress_bar = show_progress_bar
def embed_minibatch(
self,
sentence_feature: Dict[str, Tensor],
begin: int,
end: int,
with_grad: bool,
copy_random_state: bool,
random_state: Optional[RandContext] = None,
) -> Tuple[Tensor, Optional[RandContext]]:
"""Do forward pass on a minibatch of the input features and return corresponding embeddings."""
grad_context = nullcontext if with_grad else torch.no_grad
random_state_context = nullcontext() if random_state is None else random_state
sentence_feature_minibatch = {k: v[begin:end] for k, v in sentence_feature.items()}
with random_state_context:
with grad_context():
random_state = RandContext(*sentence_feature_minibatch.values()) if copy_random_state else None
reps = self.model(sentence_feature_minibatch)["sentence_embedding"] # (mbsz, hdim)
return reps, random_state
def embed_minibatch_iter(
self,
sentence_feature: Dict[str, Tensor],
with_grad: bool,
copy_random_state: bool,
random_states: Optional[List[RandContext]] = None,
) -> Iterator[Tuple[Tensor, Optional[RandContext]]]:
"""Do forward pass on all the minibatches of the input features and yield corresponding embeddings."""
input_ids: Tensor = sentence_feature["input_ids"]
bsz, _ = input_ids.shape
for i, b in enumerate(
tqdm.trange(
0,
bsz,
self.mini_batch_size,
desc="Embed mini-batches",
disable=not self.show_progress_bar,
)
):
e = b + self.mini_batch_size
reps, random_state = self.embed_minibatch(
sentence_feature=sentence_feature,
begin=b,
end=e,
with_grad=with_grad,
copy_random_state=copy_random_state,
random_state=None if random_states is None else random_states[i],
)
yield reps, random_state # reps: (mbsz, hdim)
def calculate_loss_and_cache_gradients(self, reps: List[List[Tensor]]) -> Tensor:
"""Calculate the cross-entropy loss and cache the gradients wrt. the embeddings."""
embeddings_a = torch.cat(reps[0]) # (bsz, hdim)
embeddings_b = torch.cat([torch.cat(r) for r in reps[1:]]) # ((1 + nneg) * bsz, hdim)
batch_size = len(embeddings_a)
labels = torch.tensor(
range(batch_size), dtype=torch.long, device=embeddings_a.device
) # (bsz, (1 + nneg) * bsz) Example a[i] should match with b[i]
losses: List[torch.Tensor] = []
for b in tqdm.trange(
0,
batch_size,
self.mini_batch_size,
desc="Preparing caches",
disable=not self.show_progress_bar,
):
e = b + self.mini_batch_size
scores: Tensor = self.similarity_fct(embeddings_a[b:e], embeddings_b) * self.scale
loss_mbatch: torch.Tensor = self.cross_entropy_loss(scores, labels[b:e]) * len(scores) / batch_size
loss_mbatch.backward()
losses.append(loss_mbatch.detach())
loss = sum(losses).requires_grad_()
self.cache = [[r.grad for r in rs] for rs in reps] # e.g. 3 * bsz/mbsz * (mbsz, hdim)
return loss
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor) -> Tensor:
# Step (1): A quick embedding step without gradients/computation graphs to get all the embeddings
reps = []
self.random_states = [] # Copy random states to guarantee exact reproduction of the embeddings during the second forward pass, i.e. step (3)
for sentence_feature in sentence_features:
reps_mbs = []
random_state_mbs = []
for reps_mb, random_state in self.embed_minibatch_iter(
sentence_feature=sentence_feature,
with_grad=False,
copy_random_state=True,
):
reps_mbs.append(reps_mb.detach().requires_grad_())
random_state_mbs.append(random_state)
reps.append(reps_mbs)
self.random_states.append(random_state_mbs)
# Step (2): Calculate the loss, backward up to the embeddings and cache the gradients wrt. to the embeddings
loss = self.calculate_loss_and_cache_gradients(reps)
# Step (3): A second embedding step with gradients/computation graphs; the backward hook below connects the cached gradients into the backward chain
loss.register_hook(partial(_backward_hook, sentence_features=sentence_features, loss_obj=self))
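# The hook re-runs the embedding forward pass per mini-batch with gradients enabled (replaying the stored random states) and backpropagates the cached embedding gradients into the model parameters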
return loss
def get_config_dict(self):
return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from ..SentenceTransformer import SentenceTransformer
from .. import util
class CoSENTLoss(nn.Module):
def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.pairwise_cos_sim):
"""
This class implements CoSENT (Cosine Sentence) loss.
It expects that each of the InputExamples consists of a pair of texts and a float valued label, representing
the expected similarity score between the pair.
It computes the following loss function:
``loss = log(1 + sum(exp(s(k,l) - s(i,j))))``, where ``(i,j)`` and ``(k,l)`` are any of the input pairs in the
batch such that the expected similarity of ``(i,j)`` is greater than that of ``(k,l)``. The summation runs over
all pairs of input pairs in the batch that satisfy this condition.
Anecdotal experiments show that this loss function produces a more powerful training signal than :class:`CosineSimilarityLoss`,
resulting in faster convergence and a final model with superior performance. Consequently, CoSENTLoss may be used
as a drop-in replacement for :class:`CosineSimilarityLoss` in any training script.
:param model: SentenceTransformerModel
:param similarity_fct: Function to compute the PAIRWISE similarity between embeddings. Default is ``util.pairwise_cos_sim``.
:param scale: Output of similarity function is multiplied by scale value. Represents the inverse temperature.
References:
- For further details, see: https://kexue.fm/archives/8847
Requirements:
- Sentence pairs with corresponding similarity scores in range of the similarity function. Default is [-1,1].
Relations:
- :class:`AnglELoss` is CoSENTLoss with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``.
- :class:`CosineSimilarityLoss` seems to produce a weaker training signal than CoSENTLoss. In our experiments, CoSENTLoss is recommended.
Inputs:
+--------------------------------+------------------------+
| Texts | Labels |
+================================+========================+
| (sentence_A, sentence_B) pairs | float similarity score |
+--------------------------------+------------------------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('bert-base-uncased')
train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=1.0),
InputExample(texts=['My third sentence', 'Unrelated sentence'], label=0.3)]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.CoSENTLoss(model=model)
"""
super(CoSENTLoss, self).__init__()
self.model = model
self.similarity_fct = similarity_fct
self.scale = scale
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
scores = self.similarity_fct(embeddings[0], embeddings[1])
scores = scores * self.scale
scores = scores[:, None] - scores[None, :]
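# scores[i, j] = s(i) - s(j): difference between the scaled similarities of the i-th and j-th input pairs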
# label matrix indicating which pairs are relevant
labels = labels[:, None] < labels[None, :]
labels = labels.float()
# mask out irrelevant pairs so they are negligible after exp()
scores = scores - (1 - labels) * 1e12
# append a zero as e^0 = 1
scores = torch.cat((torch.zeros(1).to(scores.device), scores.view(-1)), dim=0)
loss = torch.logsumexp(scores, dim=0)
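# logsumexp over [0, s(k,l) - s(i,j), ...] equals log(1 + sum(exp(s(k,l) - s(i,j)))), i.e. the CoSENT objective from the class docstring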
return loss
def get_config_dict(self):
return {"scale": self.scale, "similarity_fct": self.similarity_fct.__name__}
from enum import Enum
from typing import Iterable, Dict
import torch.nn.functional as F
from torch import nn, Tensor
from sentence_transformers.SentenceTransformer import SentenceTransformer
class SiameseDistanceMetric(Enum):
"""
The metric for the contrastive loss
"""
EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
COSINE_DISTANCE = lambda x, y: 1 - F.cosine_similarity(x, y)
class ContrastiveLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
distance_metric=SiameseDistanceMetric.COSINE_DISTANCE,
margin: float = 0.5,
size_average: bool = True,
):
"""
Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
:param model: SentenceTransformer model
:param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
:param margin: Negative samples (label == 0) should have a distance of at least the margin value.
:param size_average: Average by the size of the mini-batch.
References:
* Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
* `Training Examples > Quora Duplicate Questions <../../examples/training/quora_duplicate_questions/README.html>`_
Requirements:
1. (anchor, positive/negative) pairs
Relations:
- :class:`OnlineContrastiveLoss` is similar, but uses hard positive and hard negative pairs.
It often yields better results.
Inputs:
+-----------------------------------------------+------------------------------+
| Texts | Labels |
+===============================================+==============================+
| (anchor, positive/negative) pairs | 1 if positive, 0 if negative |
+-----------------------------------------------+------------------------------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.readers import InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('all-MiniLM-L6-v2')
train_examples = [
InputExample(texts=['This is a positive pair', 'Where the distance will be minimized'], label=1),
InputExample(texts=['This is a negative pair', 'Their distance will be increased'], label=0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.ContrastiveLoss(model=model)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(ContrastiveLoss, self).__init__()
self.distance_metric = distance_metric
self.margin = margin
self.model = model
self.size_average = size_average
def get_config_dict(self):
distance_metric_name = self.distance_metric.__name__
for name, value in vars(SiameseDistanceMetric).items():
if value == self.distance_metric:
distance_metric_name = "SiameseDistanceMetric.{}".format(name)
break
return {"distance_metric": distance_metric_name, "margin": self.margin, "size_average": self.size_average}
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
assert len(reps) == 2
rep_anchor, rep_other = reps
distances = self.distance_metric(rep_anchor, rep_other)
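# Hadsell et al. contrastive objective: positives (label == 1) are pulled together via d^2, negatives are pushed apart until they are at least `margin` away via relu(margin - d)^2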
losses = 0.5 * (
labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2)
)
return losses.mean() if self.size_average else losses.sum()
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from ..SentenceTransformer import SentenceTransformer
from .. import util
import copy
import random
import math
from .. import InputExample
import numpy as np
class ContrastiveTensionLoss(nn.Module):
"""
This loss expects only single sentences, without any labels. Positive and negative pairs are automatically created via random sampling,
such that a positive pair consists of two identical sentences and a negative pair consists of two different sentences. An independent
copy of the encoder model is created, which is used for encoding the first sentence of each pair. The original encoder model encodes the
second sentence. The embeddings are compared and scored against the generated labels (1 if positive, 0 if negative) using the binary cross-entropy objective.
Note that you must use the `ContrastiveTensionDataLoader` for this loss. The `pos_neg_ratio` of the ContrastiveTensionDataLoader can be
used to determine the number of negative pairs per positive pair.
Generally, :class:`ContrastiveTensionLossInBatchNegatives` is recommended over this loss, as it gives a stronger training signal.
:param model: SentenceTransformer model
References:
* Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF
* `Unsupervised Learning > CT <../../examples/unsupervised_learning/CT/README.html>`_
Relations:
* :class:`ContrastiveTensionLossInBatchNegatives` uses in-batch negative sampling, which gives a stronger training signal than this loss.
Inputs:
+------------------+--------+
| Texts | Labels |
+==================+========+
| single sentences | none |
+------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.losses import ContrastiveTensionDataLoader
model = SentenceTransformer('all-MiniLM-L6-v2')
train_examples = [
'This is the 1st sentence',
'This is the 2nd sentence',
'This is the 3rd sentence',
'This is the 4th sentence',
'This is the 5th sentence',
'This is the 6th sentence',
'This is the 7th sentence',
'This is the 8th sentence',
'This is the 9th sentence',
'This is the final sentence',
]
train_dataloader = ContrastiveTensionDataLoader(train_examples, batch_size=3, pos_neg_ratio=3)
train_loss = losses.ContrastiveTensionLoss(model=model)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
def __init__(self, model: SentenceTransformer):
super(ContrastiveTensionLoss, self).__init__()
self.model2 = model # This will be the final model used during inference.
self.model1 = copy.deepcopy(model)
self.criterion = nn.BCEWithLogitsLoss(reduction="sum")
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
sentence_features1, sentence_features2 = tuple(sentence_features)
reps_1 = self.model1(sentence_features1)["sentence_embedding"] # (bsz, hdim)
reps_2 = self.model2(sentence_features2)["sentence_embedding"]
sim_scores = (
torch.matmul(reps_1[:, None], reps_2[:, :, None]).squeeze(-1).squeeze(-1)
) # (bsz,) dot product, i.e. S1S2^T
loss = self.criterion(sim_scores, labels.type_as(sim_scores))
return loss
class ContrastiveTensionLossInBatchNegatives(nn.Module):
def __init__(self, model: SentenceTransformer, scale: float = 20.0, similarity_fct=util.cos_sim):
"""
This loss expects only single sentences, without any labels. Positive and negative pairs are automatically created via random sampling,
such that a positive pair consists of two identical sentences and a negative pair consists of two different sentences. An independent
copy of the encoder model is created, which is used for encoding the first sentence of each pair. The original encoder model encodes the
second sentence. Unlike :class:`ContrastiveTensionLoss`, this loss uses the batch negative sampling strategy, i.e. the negative pairs
are sampled from the batch. Using in-batch negative sampling gives a stronger training signal than the original :class:`ContrastiveTensionLoss`.
The performance usually increases with increasing batch sizes.
Note that you should not use the `ContrastiveTensionDataLoader` for this loss, but just a normal DataLoader with `InputExample` instances.
The two texts of each `InputExample` instance should be identical.
:param model: SentenceTransformer model
:param scale: Output of similarity function is multiplied by scale value
:param similarity_fct: similarity function between sentence embeddings. By default, cos_sim. Can also be set to dot product (and then set scale to 1)
References:
- Semantic Re-Tuning with Contrastive Tension: https://openreview.net/pdf?id=Ov_sMNau-PF
- `Unsupervised Learning > CT (In-Batch Negatives) <../../examples/unsupervised_learning/CT_In-Batch_Negatives/README.html>`_
Relations:
* :class:`ContrastiveTensionLoss` does not select negative pairs in-batch, resulting in a weaker training signal than this loss.
Inputs:
+------------------------+--------+
| Texts | Labels |
+========================+========+
| (anchor, anchor) pairs | none |
+------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
model = SentenceTransformer('all-MiniLM-L6-v2')
train_examples = [
InputExample(texts=['This is the 1st sentence', 'This is the 1st sentence']),
InputExample(texts=['This is the 2nd sentence', 'This is the 2nd sentence']),
InputExample(texts=['This is the 3rd sentence', 'This is the 3rd sentence']),
InputExample(texts=['This is the 4th sentence', 'This is the 4th sentence']),
InputExample(texts=['This is the 5th sentence', 'This is the 5th sentence']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.ContrastiveTensionLossInBatchNegatives(model=model)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(ContrastiveTensionLossInBatchNegatives, self).__init__()
self.model2 = model # This will be the final model used during inference.
self.model1 = copy.deepcopy(model)
self.similarity_fct = similarity_fct
self.cross_entropy_loss = nn.CrossEntropyLoss()
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(scale))
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
sentence_features1, sentence_features2 = tuple(sentence_features)
embeddings_a = self.model1(sentence_features1)["sentence_embedding"] # (bsz, hdim)
embeddings_b = self.model2(sentence_features2)["sentence_embedding"]
scores = self.similarity_fct(embeddings_a, embeddings_b) * self.logit_scale.exp() # self.scale
labels = torch.tensor(range(len(scores)), dtype=torch.long, device=scores.device)
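# Symmetric cross-entropy over both directions (rows and columns), with the diagonal entries as the matching (identical) pairs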
return (self.cross_entropy_loss(scores, labels) + self.cross_entropy_loss(scores.t(), labels)) / 2
################# CT Data Loader #################
# For CT, we need batches in a specific format
# In each batch, one pair is a positive (i.e. [sentA, sentA]) and the remaining pos_neg_ratio - 1 pairs are negatives (i.e. [sentA, sentB]); with the default pos_neg_ratio of 8, that is 7 negatives per positive.
# To achieve this, we create a custom DataLoader that produces batches with this property
class ContrastiveTensionDataLoader:
def __init__(self, sentences, batch_size, pos_neg_ratio=8):
self.sentences = sentences
self.batch_size = batch_size
self.pos_neg_ratio = pos_neg_ratio
self.collate_fn = None
if self.batch_size % self.pos_neg_ratio != 0:
raise ValueError(
f"ContrastiveTensionDataLoader was loaded with a pos_neg_ratio of {pos_neg_ratio} and a batch size of {batch_size}. The batch size must be divisible by the pos_neg_ratio"
)
def __iter__(self):
random.shuffle(self.sentences)
sentence_idx = 0
batch = []
while sentence_idx + 1 < len(self.sentences):
s1 = self.sentences[sentence_idx]
if len(batch) % self.pos_neg_ratio > 0: # Negative (different) pair
sentence_idx += 1
s2 = self.sentences[sentence_idx]
label = 0
else: # Positive (identical pair)
s2 = self.sentences[sentence_idx]
label = 1
sentence_idx += 1
batch.append(InputExample(texts=[s1, s2], label=label))
if len(batch) >= self.batch_size:
yield self.collate_fn(batch) if self.collate_fn is not None else batch
batch = []
def __len__(self):
return math.floor(len(self.sentences) / (2 * self.batch_size))
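# Illustrative sketch (hypothetical sentence names, not produced verbatim by this code): with batch_size=4 and pos_neg_ratio=4,
# a yielded batch looks like
#   [InputExample(texts=['s1', 's1'], label=1),  # identical (positive) pair
#    InputExample(texts=['s2', 's3'], label=0),  # different (negative) pair
#    InputExample(texts=['s4', 's5'], label=0),
#    InputExample(texts=['s6', 's7'], label=0)]
# i.e. one positive pair followed by pos_neg_ratio - 1 negative pairs, each negative consuming two fresh sentences.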
import torch
from torch import nn, Tensor
from typing import Iterable, Dict
from ..SentenceTransformer import SentenceTransformer
class CosineSimilarityLoss(nn.Module):
def __init__(self, model: SentenceTransformer, loss_fct=nn.MSELoss(), cos_score_transformation=nn.Identity()):
"""
CosineSimilarityLoss expects that the InputExamples consists of two texts and a float label. It computes the
vectors ``u = model(sentence_A)`` and ``v = model(sentence_B)`` and measures the cosine-similarity between the two.
By default, it minimizes the following loss: ``||input_label - cos_score_transformation(cosine_sim(u,v))||_2``.
:param model: SentenceTransformer model
:param loss_fct: Which pytorch loss function should be used to compare the ``cosine_similarity(u, v)`` with the input_label?
By default, MSE is used: ``||input_label - cosine_sim(u, v)||_2``
:param cos_score_transformation: The cos_score_transformation function is applied on top of cosine_similarity.
By default, the identity function is used (i.e. no change).
References:
- `Training Examples > Semantic Textual Similarity <../../examples/training/sts/README.html>`_
Requirements:
1. Sentence pairs with corresponding similarity scores in range `[0, 1]`
Relations:
- :class:`CoSENTLoss` seems to produce a stronger training signal than CosineSimilarityLoss. In our experiments, CoSENTLoss is recommended.
- :class:`AnglELoss` is :class:`CoSENTLoss` with ``pairwise_angle_sim`` as the metric, rather than ``pairwise_cos_sim``. It also produces a stronger training signal than CosineSimilarityLoss.
Inputs:
+--------------------------------+------------------------+
| Texts | Labels |
+================================+========================+
| (sentence_A, sentence_B) pairs | float similarity score |
+--------------------------------+------------------------+
Example:
::
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
train_examples = [
InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)
]
train_batch_size = 1
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.CosineSimilarityLoss(model=model)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(CosineSimilarityLoss, self).__init__()
self.model = model
self.loss_fct = loss_fct
self.cos_score_transformation = cos_score_transformation
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
output = self.cos_score_transformation(torch.cosine_similarity(embeddings[0], embeddings[1]))
return self.loss_fct(output, labels.view(-1))
from torch import nn, Tensor
from typing import Iterable, Dict
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
import logging
logger = logging.getLogger(__name__)
class DenoisingAutoEncoderLoss(nn.Module):
def __init__(self, model: SentenceTransformer, decoder_name_or_path: str = None, tie_encoder_decoder: bool = True):
"""
This loss expects as input pairs of damaged sentences and their corresponding original sentences.
During training, the decoder reconstructs the original sentences from the encoded sentence embeddings.
Here the argument 'decoder_name_or_path' indicates the pretrained model (supported by Hugging Face) to be used as the decoder.
Since a decoding step is included, the decoder should have a class called XXXLMHead (in the context of Hugging Face's Transformers).
The 'tie_encoder_decoder' flag indicates whether to tie the trainable parameters of the encoder and decoder,
which has been shown to benefit model performance while limiting the amount of required memory.
The 'tie_encoder_decoder' flag only works when the encoder and decoder share the same architecture.
The data generation process (i.e. the 'damaging' process) has already been implemented in ``DenoisingAutoEncoderDataset``,
allowing you to only provide regular sentences.
:param model: SentenceTransformer model
:param decoder_name_or_path: Model name or path for initializing a decoder (compatible with Huggingface's Transformers)
:param tie_encoder_decoder: whether to tie the trainable parameters of encoder and decoder
References:
* TSDAE paper: https://arxiv.org/pdf/2104.06979.pdf
* `Unsupervised Learning > TSDAE <../../examples/unsupervised_learning/TSDAE/README.html>`_
Requirements:
1. The decoder should have a class called XXXLMHead (in the context of Hugging Face's Transformers)
2. Should use a large corpus
Inputs:
+------------------------------------------------------+--------+
| Texts | Labels |
+======================================================+========+
| (damaged\_sentence, original\_sentence) pairs | none |
+------------------------------------------------------+--------+
| sentence fed through ``DenoisingAutoEncoderDataset`` | none |
+------------------------------------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses
from sentence_transformers.datasets import DenoisingAutoEncoderDataset
from torch.utils.data import DataLoader
model_name = "bert-base-cased"
model = SentenceTransformer(model_name)
train_sentences = [
"First training sentence", "Second training sentence", "Third training sentence", "Fourth training sentence",
]
batch_size = 2
train_dataset = DenoisingAutoEncoderDataset(train_sentences)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
train_loss = losses.DenoisingAutoEncoderLoss(
model, decoder_name_or_path=model_name, tie_encoder_decoder=True
)
model.fit(
train_objectives=[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(DenoisingAutoEncoderLoss, self).__init__()
self.encoder = model # This will be the final model used during inference.
self.tokenizer_encoder = model.tokenizer
encoder_name_or_path = model[0].auto_model.config._name_or_path
if decoder_name_or_path is None:
assert (
tie_encoder_decoder
), "Must indicate the decoder_name_or_path argument when tie_encoder_decoder=False!"
if tie_encoder_decoder:
if decoder_name_or_path:
logger.warning("When tie_encoder_decoder=True, the decoder_name_or_path will be invalid.")
decoder_name_or_path = encoder_name_or_path
self.tokenizer_decoder = AutoTokenizer.from_pretrained(decoder_name_or_path)
self.need_retokenization = not isinstance(self.tokenizer_encoder, type(self.tokenizer_decoder))
decoder_config = AutoConfig.from_pretrained(decoder_name_or_path)
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True
kwargs_decoder = {"config": decoder_config}
try:
self.decoder = AutoModelForCausalLM.from_pretrained(decoder_name_or_path, **kwargs_decoder)
except ValueError as e:
logger.error(
f'Model name or path "{decoder_name_or_path}" does not support being used as a decoder. Please make sure the decoder model has an "XXXLMHead" class.'
)
raise e
assert model[0].auto_model.config.hidden_size == decoder_config.hidden_size, "Hidden sizes do not match!"
if self.tokenizer_decoder.pad_token is None:
# Needed by GPT-2, etc.
self.tokenizer_decoder.pad_token = self.tokenizer_decoder.eos_token
self.decoder.config.pad_token_id = self.decoder.config.eos_token_id
if len(AutoTokenizer.from_pretrained(encoder_name_or_path)) != len(self.tokenizer_encoder):
logger.warning(
"WARNING: The vocabulary of the encoder has been changed. One might need to change the decoder vocabulary, too."
)
if tie_encoder_decoder:
assert not self.need_retokenization, "The tokenizers should be the same when tie_encoder_decoder=True."
if len(self.tokenizer_encoder) != len(self.tokenizer_decoder): # The vocabulary has been changed.
self.tokenizer_decoder = self.tokenizer_encoder
self.decoder.resize_token_embeddings(len(self.tokenizer_decoder))
logger.warning(
"Since the encoder vocabulary has been changed and --tie_encoder_decoder=True, now the new vocabulary has also been used for the decoder."
)
decoder_base_model_prefix = self.decoder.base_model_prefix
PreTrainedModel._tie_encoder_decoder_weights(
model[0].auto_model, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
)
def retokenize(self, sentence_features):
input_ids = sentence_features["input_ids"]
device = input_ids.device
sentences_decoded = self.tokenizer_encoder.batch_decode(
input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)
retokenized = self.tokenizer_decoder(
sentences_decoded, padding=True, truncation="longest_first", return_tensors="pt", max_length=None
).to(device)
return retokenized
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
source_features, target_features = tuple(sentence_features)
if self.need_retokenization:
# since the sentence_features here are all tokenized by the encoder's tokenizer,
# they must be retokenized with the decoder's tokenizer if the two tokenizers differ
target_features = self.retokenize(target_features)
reps = self.encoder(source_features)["sentence_embedding"] # (bsz, hdim)
# Prepare input and output
target_length = target_features["input_ids"].shape[1]
decoder_input_ids = target_features["input_ids"].clone()[:, : target_length - 1]
label_ids = target_features["input_ids"][:, 1:]
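# Standard teacher-forcing shift: the decoder sees tokens [0, n-1) as input and is trained to predict tokens [1, n), conditioned on the sentence embedding via cross-attention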
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
inputs_embeds=None,
attention_mask=None,
encoder_hidden_states=reps[:, None], # (bsz, hdim) -> (bsz, 1, hdim)
encoder_attention_mask=source_features["attention_mask"][:, 0:1],
labels=None,
return_dict=None,
use_cache=False,
)
# Calculate loss
lm_logits = decoder_outputs[0]
ce_loss_fct = nn.CrossEntropyLoss(ignore_index=self.tokenizer_decoder.pad_token_id)
loss = ce_loss_fct(lm_logits.view(-1, lm_logits.shape[-1]), label_ids.reshape(-1))
return loss
from typing import Any, Iterable, Dict
import torch
from torch import nn, Tensor
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.models import Transformer
class GISTEmbedLoss(nn.Module):
def __init__(
self,
model: SentenceTransformer,
guide: SentenceTransformer,
temperature: float = 0.01,
):
"""
This loss is used to train a SentenceTransformer model using the GISTEmbed algorithm.
It takes a model and a guide model as input, and uses the guide model to guide the
in-batch negative sample selection. The cosine similarity is used to compute the loss
and the temperature parameter is used to scale the cosine similarities.
:param model: SentenceTransformer model based on a `transformers` model.
:param guide: SentenceTransformer model to guide the in-batch negative sample selection.
:param temperature: Temperature parameter to scale the cosine similarities.
References:
- For further details, see: https://arxiv.org/abs/2402.16829
Requirements:
1. (anchor, positive, negative) triplets
2. (anchor, positive) pairs
Relations:
- :class:`MultipleNegativesRankingLoss` is similar to this loss, but it does not use
a guide model to guide the in-batch negative sample selection. `GISTEmbedLoss` yields
a stronger training signal at the cost of some training overhead.
Inputs:
+---------------------------------------+--------+
| Texts | Labels |
+=======================================+========+
| (anchor, positive, negative) triplets | none |
+---------------------------------------+--------+
| (anchor, positive) pairs | none |
+---------------------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('all-MiniLM-L6-v2')
guide = SentenceTransformer('avsolatorio/GIST-small-Embedding-v0')
train_examples = [
InputExample(texts=['The first query', 'The first positive passage', 'The first negative passage']),
InputExample(texts=['The second query', 'The second positive passage', 'The second negative passage']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.GISTEmbedLoss(model=model, guide=guide)
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(GISTEmbedLoss, self).__init__()
self.model = model
self.guide = guide
self.temperature = temperature
self.similarity_fct = nn.CosineSimilarity(dim=-1)
if not isinstance(model[0], Transformer) or not isinstance(guide[0], Transformer):
raise ValueError(
"Both the training model and the guiding model must be based on the `transformers` architecture."
)
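# Retokenization with the guide's tokenizer is needed when the two models use different vocabularies or the guide supports a shorter sequence length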
self.must_retokenize = (
model.tokenizer.vocab != guide.tokenizer.vocab or guide.max_seq_length < model.max_seq_length
)
def sim_matrix(self, embed1, embed2):
return self.similarity_fct(embed1.unsqueeze(1), embed2.unsqueeze(0))
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
embeddings = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
with torch.no_grad():
if self.must_retokenize:
decoded = [
self.model.tokenizer.batch_decode(sentence_feature["input_ids"], skip_special_tokens=True)
for sentence_feature in sentence_features
]
sentence_features = [self.guide.tokenize(sentences) for sentences in decoded]
sentence_features = [
{key: value.to(self.guide.device) for key, value in sentence_feature.items()}
for sentence_feature in sentence_features
]
guide_embeddings = [
self.guide(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features
]
negative = None
negative_guide = None
if len(embeddings) == 2:
anchor, positive = embeddings
anchor_guide, positive_guide = guide_embeddings
elif len(embeddings) == 3:
anchor, positive, negative = embeddings
anchor_guide, positive_guide, negative_guide = guide_embeddings
else:
raise ValueError("Expected 2 or 3 embeddings, got {}".format(len(embeddings)))
# Compute the model's similarities
ap_sim = self.sim_matrix(anchor, positive)
aa_sim = self.sim_matrix(anchor, anchor)
pp_sim = self.sim_matrix(positive, positive)
# Let's compute the similarity matrices for the combinations of anchor and positive samples.
guided_ap_sim = self.sim_matrix(anchor_guide, positive_guide)
guided_aa_sim = self.sim_matrix(anchor_guide, anchor_guide)
guided_pp_sim = self.sim_matrix(positive_guide, positive_guide)
# Define the anchor threshold
guided_sim = guided_ap_sim.diagonal().view(-1, 1)
# Find which samples cannot be used as negatives because they are
# more similar to the query than the assigned positive as deemed by the guide model.
# For these samples, we mask them with -inf to basically ignore their contribution to
# the loss.
ap_sim[guided_ap_sim > guided_sim] = -torch.inf
aa_sim[guided_aa_sim > guided_sim] = -torch.inf
pp_sim[guided_pp_sim > guided_sim] = -torch.inf
scores = [ap_sim, aa_sim, pp_sim]
# Handle the case where we have a negative sample
if negative is not None:
an_sim = self.sim_matrix(anchor, negative)
guided_an_sim = self.sim_matrix(anchor_guide, negative_guide)
an_sim[guided_an_sim > guided_sim] = -torch.inf
scores.append(an_sim)
scores = torch.cat(scores, dim=1) / self.temperature
# NOTE: We use arange here since the ap_sim matrix contains the anchor-positive
# similarities along the diagonal.
labels = torch.arange(scores.size(0)).long().to(scores.device)
return nn.CrossEntropyLoss()(scores, labels)
def get_config_dict(self) -> Dict[str, Any]:
return {
"guide": self.guide,
"temperature": self.temperature,
}
from torch import nn, Tensor
from typing import Iterable, Dict
class MSELoss(nn.Module):
def __init__(self, model):
"""
Computes the MSE loss between the computed sentence embedding and a target sentence embedding. This loss
is used when extending sentence embeddings to new languages as described in our publication
Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation.
For an example, see `the distillation documentation <../../examples/training/distillation/README.html>`_ on extending language models to new languages.
:param model: SentenceTransformerModel
References:
- Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation: https://arxiv.org/abs/2004.09813
- `Training > Model Distillation <../../examples/training/distillation/README.html>`_
- `Training > Multilingual Models <../../examples/training/multilingual/README.html>`_
Requirements:
1. Usually uses a finetuned teacher M in a knowledge distillation setup
Relations:
- :class:`MarginMSELoss` is equivalent to this loss, but with a margin through a negative pair.
Input:
+-------------------+-----------------------------+
| Texts | Labels |
+===================+=============================+
| single sentences | model sentence embeddings |
+-------------------+-----------------------------+
Example::
from sentence_transformers import SentenceTransformer, InputExample, losses
from torch.utils.data import DataLoader
model_en = SentenceTransformer('bert-base-cased')
model_fr = SentenceTransformer('flaubert/flaubert_base_cased')
examples_en = ['The first sentence', 'The second sentence', 'The third sentence', 'The fourth sentence']
examples_fr = ['La première phrase', 'La deuxième phrase', 'La troisième phrase', 'La quatrième phrase']
train_batch_size = 2
labels_en_en = model_en.encode(examples_en)
examples_en_fr = [InputExample(texts=[x], label=labels_en_en[i]) for i, x in enumerate(examples_en)]
loader_en_fr = DataLoader(examples_en_fr, batch_size=train_batch_size)
examples_fr_fr = [InputExample(texts=[x], label=labels_en_en[i]) for i, x in enumerate(examples_fr)]
loader_fr_fr = DataLoader(examples_fr_fr, batch_size=train_batch_size)
train_loss = losses.MSELoss(model=model_fr)
model_fr.fit(
[(loader_en_fr, train_loss), (loader_fr_fr, train_loss)],
epochs=10,
)
"""
super(MSELoss, self).__init__()
self.model = model
self.loss_fct = nn.MSELoss()
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
rep = self.model(sentence_features[0])["sentence_embedding"]
return self.loss_fct(rep, labels)
from .. import util
from torch import nn, Tensor
from typing import Iterable, Dict
class MarginMSELoss(nn.Module):
def __init__(self, model, similarity_fct=util.pairwise_dot_score):
"""
Computes the MSE loss between the predicted margin ``sim(Query, Pos) - sim(Query, Neg)`` and the gold margin ``gold_sim(Query, Pos) - gold_sim(Query, Neg)``.
By default, sim() is the dot-product. The gold_sim is often the similarity score from a teacher model.
In contrast to :class:`MultipleNegativesRankingLoss`, the two passages do not have to be strictly positive and negative,
both can be relevant or not relevant for a given query. This can be an advantage of MarginMSELoss over
MultipleNegativesRankingLoss, but note that the MarginMSELoss is much slower to train. With MultipleNegativesRankingLoss,
with a batch size of 64, we compare one query against 128 passages. With MarginMSELoss, we compare a query only
against two passages.
:param model: SentenceTransformerModel
:param similarity_fct: Which similarity function to use.
References:
- For more details, please refer to https://arxiv.org/abs/2010.02666.
- `Training Examples > MS MARCO <../../examples/training/ms_marco/README.html>`_
- `Unsupervised Learning > Domain Adaptation <../../examples/domain_adaptation/README.html>`_
Requirements:
1. (query, passage_one, passage_two) triplets
2. Usually used with a finetuned teacher M in a knowledge distillation setup
Relations:
- :class:`MSELoss` is equivalent to this loss, but without a margin through the negative pair.
Inputs:
+-----------------------------------------------+-----------------------------------------------+
| Texts | Labels |
+===============================================+===============================================+
| (query, passage_one, passage_two) triplets | M(query, passage_one) - M(query, passage_two) |
+-----------------------------------------------+-----------------------------------------------+
Example:
::
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.util import pairwise_dot_score
from torch.utils.data import DataLoader
import torch
student_model = SentenceTransformer('sentence-transformers/distilbert-base-nli-mean-tokens')
teacher_model = SentenceTransformer('sentence-transformers/bert-base-nli-stsb-mean-tokens')
train_examples = [
['The first query', 'The first positive passage', 'The first negative passage'],
['The second query', 'The second positive passage', 'The second negative passage'],
['The third query', 'The third positive passage', 'The third negative passage'],
]
train_batch_size = 1
encoded = torch.tensor([teacher_model.encode(x).tolist() for x in train_examples])
labels = pairwise_dot_score(encoded[:, 0], encoded[:, 1]) - pairwise_dot_score(encoded[:, 0], encoded[:, 2])
train_input_examples = [InputExample(texts=x, label=labels[i]) for i, x in enumerate(train_examples)]
train_dataloader = DataLoader(train_input_examples, shuffle=True, batch_size=train_batch_size)
train_loss = losses.MarginMSELoss(model=student_model)
student_model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
super(MarginMSELoss, self).__init__()
self.model = model
self.similarity_fct = similarity_fct
self.loss_fct = nn.MSELoss()
def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
# sentence_features: query, positive passage, negative passage
reps = [self.model(sentence_feature)["sentence_embedding"] for sentence_feature in sentence_features]
embeddings_query = reps[0]
embeddings_pos = reps[1]
embeddings_neg = reps[2]
scores_pos = self.similarity_fct(embeddings_query, embeddings_pos)
scores_neg = self.similarity_fct(embeddings_query, embeddings_neg)
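# Train the student so that its positive-vs-negative margin matches the teacher's margin provided as the label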
margin_pred = scores_pos - scores_neg
return self.loss_fct(margin_pred, labels)
from typing import Any, Dict, List, Optional, Union
from torch.nn import Module
from sentence_transformers.SentenceTransformer import SentenceTransformer
from sentence_transformers.losses import AdaptiveLayerLoss, MatryoshkaLoss
class Matryoshka2dLoss(AdaptiveLayerLoss):
def __init__(
self,
model: SentenceTransformer,
loss: Module,
matryoshka_dims: List[int],
matryoshka_weights: Optional[List[Union[float, int]]] = None,
n_layers_per_step: int = 1,
n_dims_per_step: int = 1,
last_layer_weight: float = 1.0,
prior_layers_weight: float = 1.0,
kl_div_weight: float = 1.0,
kl_temperature: float = 0.3,
) -> None:
"""
The Matryoshka2dLoss can be seen as a loss *modifier* that combines the :class:`AdaptiveLayerLoss` and the
:class:`MatryoshkaLoss`. This allows you to train an embedding model that 1) allows users to specify the number
of model layers to use, and 2) allows users to specify the output dimensions to use.
The former is useful for when you want users to have the option to lower the number of layers used to improve
their inference speed and memory usage, and the latter is useful for when you want users to have the option to
lower the output dimensions to improve the efficiency of their downstream tasks (e.g. retrieval) or to lower
their storage costs.
Note, this uses `n_layers_per_step=1` and `n_dims_per_step=1` as default, following the original 2DMSE
implementation.
:param model: SentenceTransformer model
:param loss: The loss function to be used, e.g. :class:`MultipleNegativesRankingLoss`, :class:`CoSENTLoss`, etc.
:param matryoshka_dims: A list of embedding dimensions to be used for the loss function, e.g. [768, 512, 256, 128, 64].
:param matryoshka_weights: A list of weights to be used for the loss function, e.g. [1, 1, 1, 1, 1]. If None, then the
weights will be set to 1 for all dimensions.
:param n_layers_per_step: The number of layers to use per step. If -1, then all layers are used. If > 0, then
a random sample of n_layers_per_step layers is used per step. The 2DMSE paper uses `n_layers_per_step=1`,
which is also the default value here.
:param n_dims_per_step: The number of dimensions to use per step. If -1, then all dimensions are used. If > 0, then
a random sample of n_dims_per_step dimensions is used per step. The default value is 1.
:param last_layer_weight: The weight to use for the loss of the final layer. Increase this to focus more on the
performance when using all layers. The default value is 1.0.
:param prior_layers_weight: The weight to use for the loss of the prior layers. Increase this to focus more on
the performance when using fewer layers. The default value is 1.0.
:param kl_div_weight: The weight to use for the KL-divergence loss that is used to make the prior layers match
that of the last layer. Increase this to focus more on the performance when using fewer layers. The default
value is 1.0.
:param kl_temperature: The temperature to use for the KL-divergence loss. If 0, then the KL-divergence loss is
not used. The default value is 0.3.
References:
- See the 2D Matryoshka Sentence Embeddings (2DMSE) paper: https://arxiv.org/abs/2402.14776
- `Matryoshka Embeddings <../../examples/training/matryoshka/README.html>`_
- `Adaptive Layers <../../examples/training/adaptive_layer/README.html>`_
Requirements:
1. The base loss cannot be :class:`CachedMultipleNegativesRankingLoss`.
Relations:
- :class:`MatryoshkaLoss` is used in this loss, and it is responsible for the dimensionality reduction.
- :class:`AdaptiveLayerLoss` is used in this loss, and it is responsible for the layer reduction.
Input:
+---------------------------------------+--------+
| Texts | Labels |
+=======================================+========+
| any | any |
+---------------------------------------+--------+
Example:
::
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import DataLoader
model = SentenceTransformer('microsoft/mpnet-base')
train_examples = [
InputExample(texts=['Anchor 1', 'Positive 1']),
InputExample(texts=['Anchor 2', 'Positive 2']),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=32)
train_loss = losses.MultipleNegativesRankingLoss(model=model)
train_loss = losses.Matryoshka2dLoss(model, train_loss, [768, 512, 256, 128, 64])
model.fit(
[(train_dataloader, train_loss)],
epochs=10,
)
"""
matryoshka_loss = MatryoshkaLoss(
model,
loss,
matryoshka_dims,
matryoshka_weights=matryoshka_weights,
n_dims_per_step=n_dims_per_step,
)
super().__init__(
model,
matryoshka_loss,
n_layers_per_step=n_layers_per_step,
last_layer_weight=last_layer_weight,
prior_layers_weight=prior_layers_weight,
kl_div_weight=kl_div_weight,
kl_temperature=kl_temperature,
)
def get_config_dict(self) -> Dict[str, Any]:
return {
**super().get_config_dict(),
**self.loss.get_config_dict(),
}