"""
The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) on the SNLI + MultiNLI (AllNLI) dataset
with MatryoshkaLoss using MultipleNegativesRankingLoss. This trains a model at output dimensions [768, 512, 256, 128, 64].
Entailments are positive pairs and the contradiction on the AllNLI dataset is added as a hard negative.

At every 10% of the training steps, the model is evaluated on the STS benchmark dataset at the different output dimensions.

Usage:
python matryoshka_nli.py

OR
python matryoshka_nli.py pretrained_transformer_model_name
"""

import logging
import sys
import traceback
from datetime import datetime

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
)
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SequentialEvaluator, SimilarityFunction
from sentence_transformers.training_args import BatchSamplers

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

# Base checkpoint can be overridden from the command line: `python matryoshka_nli.py <model_name>`
model_name = sys.argv[1] if len(sys.argv) > 1 else "distilroberta-base"
batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
num_train_epochs = 1
# Output dimensions trained jointly by MatryoshkaLoss, from full size down to 64
matryoshka_dims = [768, 512, 256, 128, 64]

# Save path of the model
output_dir = f"output/matryoshka_nli_{model_name.replace('/', '-')}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
# create one with "mean" pooling.
model = SentenceTransformer(model_name)
# If we want, we can limit the maximum sequence length for the model
# model.max_seq_length = 75
logging.info(model)

# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
# The "triplet" subset yields (anchor, positive/entailment, negative/contradiction) rows.
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
logging.info(train_dataset)

# If you wish, you can limit the number of training samples
# train_dataset = train_dataset.select(range(5000))

# 3. Define our training loss: in-batch negatives ranking, wrapped so it is applied
# at every Matryoshka output dimension.
inner_train_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.MatryoshkaLoss(model, inner_train_loss, matryoshka_dims=matryoshka_dims)

# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
# One STSb evaluator per output dimension, each truncating the embeddings to that dimension.
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
evaluators = []
for dim in matryoshka_dims:
    evaluators.append(
        EmbeddingSimilarityEvaluator(
            sentences1=stsb_eval_dataset["sentence1"],
            sentences2=stsb_eval_dataset["sentence2"],
            scores=stsb_eval_dataset["score"],
            main_similarity=SimilarityFunction.COSINE,
            name=f"sts-dev-{dim}",
            truncate_dim=dim,
        )
    )
# Use the full-dimension (first) evaluator's score as the main score for checkpointing/logging.
dev_evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: scores[0])

# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    logging_steps=100,
    run_name="matryoshka-nli",  # Will be used in W&B if `wandb` is installed
)

# 6. Create the trainer & start training
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)
trainer.train()

# 7. Evaluate the model performance on the STS Benchmark test dataset, again per output dimension.
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
evaluators = []
for dim in matryoshka_dims:
    evaluators.append(
        EmbeddingSimilarityEvaluator(
            sentences1=test_dataset["sentence1"],
            sentences2=test_dataset["sentence2"],
            scores=test_dataset["score"],
            main_similarity=SimilarityFunction.COSINE,
            name=f"sts-test-{dim}",
            truncate_dim=dim,
        )
    )
test_evaluator = SequentialEvaluator(evaluators)
test_evaluator(model)

# 8. Save the trained & evaluated model locally
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)

# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
try:
    model.push_to_hub(f"{model_name}-nli-matryoshka")
except Exception:
    # Uploading is best-effort: log the full traceback with manual-upload instructions, but don't crash.
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{model_name}-nli-matryoshka')`."
    )