"vscode:/vscode.git/clone" did not exist on "42caa660ef5897a25e2cdeba97dea3fd8ea75926"
model_distillation.py 10 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
"""
This file contains an example how to make a SentenceTransformer model faster and lighter.

This is achieved by using Knowledge Distillation: We use a well working teacher model to train
a fast and light student model. The student model learns to imitate the produced
sentence embeddings from the teacher. We train this on a diverse set of sentences we got
from SNLI + Multi+NLI + Wikipedia.

After the distillation is finished, the student model produce nearly the same embeddings as the
teacher, however, it will be much faster.

The script implements to options two options to initialize the student:
Option 1: Train a light transformer model like TinyBERT to imitate the teacher
Option 2: We take the teacher model and keep only certain layers, for example, only 4 layers.

Option 2) works usually better, as we keep most of the weights from the teacher. In Option 1, we have to tune all
weights in the student from scratch.

There is a performance - speed trade-off. However, we found that a student with 4 instead of 12 layers keeps about 99.4%
of the teacher performance, while being 2.3 times faster.
"""

import logging
Rayyyyy's avatar
Rayyyyy committed
24
import traceback
Rayyyyy's avatar
Rayyyyy committed
25
from datetime import datetime
Rayyyyy's avatar
Rayyyyy committed
26
27

import pandas as pd
Rayyyyy's avatar
Rayyyyy committed
28
import torch
Rayyyyy's avatar
Rayyyyy committed
29
from sklearn.decomposition import PCA
Rayyyyy's avatar
Rayyyyy committed
30

Rayyyyy's avatar
Rayyyyy committed
31
32
33
34
35
36
from datasets import Dataset, concatenate_datasets, load_dataset
from sentence_transformers import LoggingHandler, SentenceTransformer, evaluation, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments
Rayyyyy's avatar
Rayyyyy committed
37
38
39
40
41
42
43
44
45
46
47
48

#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout


# Teacher Model: Model we want to distill to a smaller model
teacher_model_name = "stsb-roberta-base-v2"
teacher_model = SentenceTransformer(teacher_model_name)

Rayyyyy's avatar
Rayyyyy committed
49
output_dir = "output/model-distillation-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
Rayyyyy's avatar
Rayyyyy committed
50

Rayyyyy's avatar
Rayyyyy committed
51
52
53
54
# We will train a small model like TinyBERT to imitate the teacher.
# You can find some small BERT models here: https://huggingface.co/nreimers
student_model_name = "nreimers/TinyBERT_L-4_H-312_v2"
student_model = SentenceTransformer(student_model_name)
Rayyyyy's avatar
Rayyyyy committed
55
56
57
58
59
60
61
62
63
64
65
66
67
68

inference_batch_size = 64
train_batch_size = 64

# We use AllNLI as a source of sentences for the distillation
nli_dataset_path = "datasets/AllNLI.tsv.gz"

# Further, we use sentences extracted from the English Wikipedia to train the distillation
wikipedia_dataset_path = "datasets/wikipedia-en-sentences.txt.gz"

# We use the STS benchmark dataset to see how much performance we loose
sts_dataset_path = "datasets/stsbenchmark.tsv.gz"


Rayyyyy's avatar
Rayyyyy committed
69
70
71
72
73
logging.info("Load the AllNLI dataset")
# Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
nli_train_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="train")
nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "pair-score", split="dev")
# Concatenate all sentences into a new column "sentence"
Rayyyyy's avatar
Rayyyyy committed
74
75


Rayyyyy's avatar
Rayyyyy committed
76
77
def combine_sentences(batch):
    return {"sentence": batch["sentence1"] + batch["sentence2"]}
Rayyyyy's avatar
Rayyyyy committed
78
79


Rayyyyy's avatar
Rayyyyy committed
80
81
82
83
nli_train_dataset = nli_train_dataset.map(
    combine_sentences, batched=True, remove_columns=nli_train_dataset.column_names
)
nli_eval_dataset = nli_eval_dataset.map(combine_sentences, batched=True, remove_columns=nli_eval_dataset.column_names)
Rayyyyy's avatar
Rayyyyy committed
84
85


Rayyyyy's avatar
Rayyyyy committed
86
87
88
89
def deduplicate(dataset):
    df = pd.DataFrame(dataset)
    df = df.drop_duplicates()
    return Dataset.from_pandas(df, preserve_index=False)
Rayyyyy's avatar
Rayyyyy committed
90
91


Rayyyyy's avatar
Rayyyyy committed
92
93
94
nli_train_dataset = deduplicate(nli_train_dataset)
nli_eval_dataset = deduplicate(nli_eval_dataset)
logging.info(nli_train_dataset)
Rayyyyy's avatar
Rayyyyy committed
95
96


Rayyyyy's avatar
Rayyyyy committed
97
98
99
100
101
logging.info("Load the STSB dataset")
# Load the STSB eval/test datasets: https://huggingface.co/datasets/sentence-transformers/stsb
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
stsb_test_dataset = load_dataset("sentence-transformers/stsb", split="test")
logging.info(stsb_eval_dataset)
Rayyyyy's avatar
Rayyyyy committed
102
103


Rayyyyy's avatar
Rayyyyy committed
104
105
106
107
108
109
110
111
logging.info("Load the Wikipedia dataset")
# Load the Wikipedia dataset: https://huggingface.co/datasets/sentence-transformers/wikipedia-en-sentences
wikipedia_train_dataset = load_dataset("sentence-transformers/wikipedia-en-sentences", split="train")
# Take 5000 random sentences from the Wikipedia dataset for evaluation
wikipedia_train_dataset_dict = wikipedia_train_dataset.train_test_split(test_size=5000)
wikipedia_train_dataset = wikipedia_train_dataset_dict["train"]
wikipedia_eval_dataset = wikipedia_train_dataset_dict["test"]
logging.info(wikipedia_train_dataset)
Rayyyyy's avatar
Rayyyyy committed
112
113


Rayyyyy's avatar
Rayyyyy committed
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
# Concatenate the NLI and Wikipedia datasets for training
train_dataset: Dataset = concatenate_datasets([nli_train_dataset, wikipedia_train_dataset])
# Create a relatively small dataset for evaluation
eval_dataset: Dataset = concatenate_datasets(
    [nli_eval_dataset.select(range(5000)), wikipedia_eval_dataset.select(range(5000))]
)

# Create an STSB evaluator
dev_evaluator_stsb = EmbeddingSimilarityEvaluator(
    sentences1=stsb_eval_dataset["sentence1"],
    sentences2=stsb_eval_dataset["sentence2"],
    scores=stsb_eval_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-dev",
)
logging.info("Teacher Performance")
dev_evaluator_stsb(teacher_model)
Rayyyyy's avatar
Rayyyyy committed
131
132
133
134

# Student model has fewer dimensions. Compute PCA for the teacher to reduce the dimensions
if student_model.get_sentence_embedding_dimension() < teacher_model.get_sentence_embedding_dimension():
    logging.info("Student model has fewer dimensions than the teacher. Compute PCA for down projection")
Rayyyyy's avatar
Rayyyyy committed
135
    pca_sentences = nli_train_dataset[:20000]["sentence"] + wikipedia_train_dataset[:20000]["sentence"]
Rayyyyy's avatar
Rayyyyy committed
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
    pca_embeddings = teacher_model.encode(pca_sentences, convert_to_numpy=True)
    pca = PCA(n_components=student_model.get_sentence_embedding_dimension())
    pca.fit(pca_embeddings)

    # Add Dense layer to teacher that projects the embeddings down to the student embedding size
    dense = models.Dense(
        in_features=teacher_model.get_sentence_embedding_dimension(),
        out_features=student_model.get_sentence_embedding_dimension(),
        bias=False,
        activation_function=torch.nn.Identity(),
    )
    dense.linear.weight = torch.nn.Parameter(torch.tensor(pca.components_))
    teacher_model.add_module("dense", dense)

    logging.info("Teacher Performance with {} dimensions:".format(teacher_model.get_sentence_embedding_dimension()))
Rayyyyy's avatar
Rayyyyy committed
151
    dev_evaluator_stsb(teacher_model)
Rayyyyy's avatar
Rayyyyy committed
152
153


Rayyyyy's avatar
Rayyyyy committed
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# Use the teacher model to get the gold embeddings
def map_embeddings(batch):
    return {
        "label": teacher_model.encode(
            batch["sentence"], batch_size=inference_batch_size, show_progress_bar=False
        ).tolist()
    }


train_dataset = train_dataset.select(range(200000))
train_dataset = train_dataset.map(map_embeddings, batched=True, batch_size=50000)
# Optionally, save the dataset to disk to speed up future runs
train_dataset.save_to_disk("datasets/distillation_train_dataset")
# from datasets import DatasetDict, load_from_disk

# train_dataset = load_from_disk("datasets/distillation_train_dataset")
# if isinstance(train_dataset, DatasetDict):
#     train_dataset = train_dataset["train"]
eval_dataset = eval_dataset.map(map_embeddings, batched=True, batch_size=50000)
Rayyyyy's avatar
Rayyyyy committed
173
174
175
176

train_loss = losses.MSELoss(model=student_model)

# We create an evaluator, that measure the Mean Squared Error (MSE) between the teacher and the student embeddings
Rayyyyy's avatar
Rayyyyy committed
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
eval_sentences = eval_dataset["sentence"]
dev_evaluator_mse = evaluation.MSEEvaluator(eval_sentences, eval_sentences, teacher_model=teacher_model)
dev_evaluator = evaluation.SequentialEvaluator([dev_evaluator_stsb, dev_evaluator_mse])

# Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=1,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    metric_for_best_model="eval_sts-dev_spearman_cosine",
    load_best_model_at_end=True,
    learning_rate=1e-4,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    run_name="distillation-layer-reduction",  # Will be used in W&B if `wandb` is installed
)

# Create the trainer & start training
trainer = SentenceTransformerTrainer(
    model=student_model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss=train_loss,
    evaluator=dev_evaluator,
Rayyyyy's avatar
Rayyyyy committed
213
)
Rayyyyy's avatar
Rayyyyy committed
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
trainer.train()

# Evaluate the model performance on the STS Benchmark test dataset
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=stsb_test_dataset["sentence1"],
    sentences2=stsb_test_dataset["sentence2"],
    scores=stsb_test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
test_evaluator(student_model)

# Save the trained & evaluated model locally
final_output_dir = f"{output_dir}/final"
student_model.save(final_output_dir)

# (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
if "/" in student_model_name:
    student_model_name = student_model_name.split("/")[-1]
if "/" in teacher_model_name:
    teacher_model_name = teacher_model_name.split("/")[-1]
repo_id = f"{student_model_name}-distilled-from-{teacher_model_name}"
try:
    student_model.push_to_hub(repo_id)
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub({repo_id!r})`."
    )