2d_matryoshka_nli.py 4.98 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
"""
The system trains BERT (or any other transformer model like RoBERTa, DistilBERT etc.) on the SNLI + MultiNLI (AllNLI) dataset
with Matryoshka2dLoss using MultipleNegativesRankingLoss. This trains a model at output dimensions [768, 512, 256, 128, 64].
Entailments are positive pairs and the contradiction in the AllNLI dataset is added as a hard negative.
Every 100 training steps, the model is evaluated on the STS benchmark dataset.

Usage:
python 2d_matryoshka_nli.py

OR
python 2d_matryoshka_nli.py pretrained_transformer_model_name
"""

import logging
import sys
Rayyyyy's avatar
Rayyyyy committed
16
17
import traceback
from datetime import datetime
Rayyyyy's avatar
Rayyyyy committed
18

Rayyyyy's avatar
Rayyyyy committed
19
20
21
22
23
24
from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    losses,
Rayyyyy's avatar
Rayyyyy committed
25
)
Rayyyyy's avatar
Rayyyyy committed
26
27
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator, SimilarityFunction
from sentence_transformers.training_args import BatchSamplers
Rayyyyy's avatar
Rayyyyy committed
28

Rayyyyy's avatar
Rayyyyy committed
29
30
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


# The base checkpoint may be given on the command line; fall back to DistilRoBERTa.
if len(sys.argv) > 1:
    model_name = sys.argv[1]
else:
    model_name = "distilroberta-base"
batch_size = 128  # The larger you select this, the better the results (usually). But it requires more GPU memory
num_train_epochs = 1

# Save path of the model: encode the (filesystem-safe) checkpoint name and a start timestamp.
safe_model_name = model_name.replace("/", "-")
started_at = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
output_dir = f"output/2d_matryoshka_nli_{safe_model_name}-{started_at}"

# 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model,
# it will automatically create one with "mean" pooling.
model = SentenceTransformer(model_name)
# If we want, we can limit the maximum sequence length for the model
# model.max_seq_length = 75
logging.info(model)
Rayyyyy's avatar
Rayyyyy committed
45

Rayyyyy's avatar
Rayyyyy committed
46
47
48
49
# 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
# Each "triplet" row carries an anchor, a positive (entailment) and a hard negative (contradiction).
train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
logging.info(train_dataset)

# If you wish, you can limit the number of training samples
# train_dataset = train_dataset.select(range(5000))

# 3. Define our training loss: MultipleNegativesRankingLoss wrapped in Matryoshka2dLoss,
# which trains the model at output dimensions [768, 512, 256, 128, 64].
base_loss = losses.MultipleNegativesRankingLoss(model)
train_loss = losses.Matryoshka2dLoss(model, base_loss, [768, 512, 256, 128, 64])
Rayyyyy's avatar
Rayyyyy committed
57

Rayyyyy's avatar
Rayyyyy committed
58
59
# 4. Define an evaluator for use during training. This is useful to keep track of alongside the evaluation loss.
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
# Pull the columns out once; the evaluator compares cosine similarity against the gold scores.
dev_sentences1 = stsb_eval_dataset["sentence1"]
dev_sentences2 = stsb_eval_dataset["sentence2"]
dev_scores = stsb_eval_dataset["score"]
dev_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=dev_sentences1,
    sentences2=dev_sentences2,
    scores=dev_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)

Rayyyyy's avatar
Rayyyyy committed
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# 5. Define the training arguments. Evaluation, checkpointing and logging all share
# the same 100-step cadence.
steps_interval = 100
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.1,  # linear LR warmup over the first 10% of steps
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
    batch_sampler=BatchSamplers.NO_DUPLICATES,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=steps_interval,
    save_strategy="steps",
    save_steps=steps_interval,
    save_total_limit=2,  # keep only the two most recent checkpoints on disk
    logging_steps=steps_interval,
    run_name="2d-matryoshka-nli",  # Will be used in W&B if `wandb` is installed
)
Rayyyyy's avatar
Rayyyyy committed
89

Rayyyyy's avatar
Rayyyyy committed
90
91
92
93
94
95
96
# 6. Create the trainer & start training
trainer_kwargs = dict(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
)
trainer = SentenceTransformerTrainer(**trainer_kwargs)
trainer.train()
Rayyyyy's avatar
Rayyyyy committed
100

Rayyyyy's avatar
Rayyyyy committed
101
102
# 7. Evaluate the model performance on the STS Benchmark test dataset
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
test_sentences1 = test_dataset["sentence1"]
test_sentences2 = test_dataset["sentence2"]
test_scores = test_dataset["score"]
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_sentences1,
    sentences2=test_sentences2,
    scores=test_scores,
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
test_evaluator(model)


# 8. Save the trained & evaluated model locally
final_output_dir = f"{output_dir}/final"
model.save(final_output_dir)
Rayyyyy's avatar
Rayyyyy committed
115

Rayyyyy's avatar
Rayyyyy committed
116
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
# Strip any organization prefix ("org/model" -> "model"); split is a no-op when there is no "/".
model_name = model_name.split("/")[-1]
try:
    model.push_to_hub(f"{model_name}-nli-2d-matryoshka")
except Exception:
    # Best-effort upload: log the failure (with traceback) and manual-upload instructions
    # instead of crashing after a successful training run.
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{model_name}-nli-2d-matryoshka')`."
    )