"""
This script contains an example how to extend an existent sentence embedding model to new languages.

Given a (monolingual) teacher model you would like to extend to new languages, which is specified in the teacher_model_name
variable. We train a multilingual student model to imitate the teacher model (variable student_model_name)
on multiple languages.

For training, you need parallel sentence data (machine translation training data). You need tab-seperated files (.tsv)
with the first column a sentence in a language understood by the teacher model, e.g. English,
and the further columns contain the according translations for languages you want to extend to.

This scripts downloads automatically the parallel sentences corpus. This corpus contains transcripts from
talks translated to 100+ languages. For other parallel data, see get_parallel_data_[].py scripts

Further information can be found in our paper:
Making Monolingual Sentence Embeddings Multilingual using Knowledge Distillation
https://arxiv.org/abs/2004.09813
"""

import logging
import traceback
from datetime import datetime

import numpy as np

from datasets import DatasetDict, load_dataset
from sentence_transformers import LoggingHandler, SentenceTransformer
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    MSEEvaluator,
    SequentialEvaluator,
    TranslationEvaluator,
)
from sentence_transformers.losses import MSELoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
logger = logging.getLogger(__name__)


# The teacher model is monolingual, we use it for English embeddings
teacher_model_name = "paraphrase-distilroberta-base-v2"
# The student model is multilingual, we train it such that embeddings of non-English texts mimic the teacher model's English embeddings
student_model_name = "xlm-roberta-base"

student_max_seq_length = 128  # Student model max. sequence length for inputs (number of word pieces)
train_batch_size = 64  # Batch size for training
inference_batch_size = 64  # Batch size at inference
max_sentences_per_language = 500000  # Maximum number of parallel sentences per language used for training

num_train_epochs = 5  # Number of epochs to train for
num_evaluation_steps = 5000  # Evaluate performance after this many training steps


# Define the language codes you would like to extend the model to
source_languages = set(["en"])  # Our teacher model accepts English (en) sentences
# We want to extend the model to these new languages. For available language codes, see the subsets of the parallel sentences datasets
target_languages = {"de", "es", "it", "fr", "ar", "tr"}


output_dir = (
    "output/make-multilingual-"
    + "-".join(sorted(list(source_languages)) + sorted(list(target_languages)))
    + "-"
    + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)

# 1a. Here we define our SentenceTransformer teacher model.
teacher_model = SentenceTransformer(teacher_model_name)
# If we want, we can limit the maximum sequence length for the model
# teacher_model.max_seq_length = 128
logging.info(f"Teacher model: {teacher_model}")

# 1b. Here we define our SentenceTransformer student model. If not already a Sentence Transformer model,
# it will automatically create one with "mean" pooling.
student_model = SentenceTransformer(student_model_name)
# If we want, we can limit the maximum sequence length for the model
student_model.max_seq_length = student_max_seq_length
logging.info(f"Student model: {student_model}")

# 2. Load the parallel sentences training dataset: https://huggingface.co/datasets?other=sentence-transformers&sort=trending&search=parallel-sentences
# NOTE: We can also use multiple datasets if we want
dataset_to_use = "sentence-transformers/parallel-sentences-talks"
# dataset_to_use = "sentence-transformers/parallel-sentences-europarl"
# dataset_to_use = "sentence-transformers/parallel-sentences-global-voices"
# dataset_to_use = "sentence-transformers/parallel-sentences-muse"
# dataset_to_use = "sentence-transformers/parallel-sentences-jw300"
# dataset_to_use = "sentence-transformers/parallel-sentences-news-commentary"
# dataset_to_use = "sentence-transformers/parallel-sentences-opensubtitles"
# dataset_to_use = "sentence-transformers/parallel-sentences-tatoeba"
# dataset_to_use = "sentence-transformers/parallel-sentences-wikimatrix"
# dataset_to_use = "sentence-transformers/parallel-sentences-wikititles"
train_dataset_dict = DatasetDict()
eval_dataset_dict = DatasetDict()
for source_lang in source_languages:
    for target_lang in target_languages:
        subset = f"{source_lang}-{target_lang}"
        try:
            train_dataset = load_dataset(dataset_to_use, subset, split="train")
            if len(train_dataset) > max_sentences_per_language:
                train_dataset = train_dataset.select(range(max_sentences_per_language))
        except Exception as exc:
            logging.error(f"Could not load dataset {dataset_to_use}/{source_lang}-{target_lang}: {exc}")
            continue

        try:
            eval_dataset = load_dataset(dataset_to_use, subset, split="dev")
            if len(eval_dataset) > 1000:
                eval_dataset = eval_dataset.select(range(1000))
        except Exception:
            logging.info(
                f"Could not load dataset {dataset_to_use}/{source_lang}-{target_lang} dev split, splitting 1k samples from train"
            )
            dataset = train_dataset.train_test_split(test_size=1000, shuffle=True)
            train_dataset = dataset["train"]
            eval_dataset = dataset["test"]

        train_dataset_dict[subset] = train_dataset
        eval_dataset_dict[subset] = eval_dataset
logging.info(train_dataset_dict)
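# Each subset contains parallel rows with "english" and "non_english" text columns, for example
# (illustrative values): {"english": "Thank you, Chris.", "non_english": "Vielen Dank, Chris."}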


# We want the student EN embeddings to be similar to the teacher EN embeddings and
# the student non-EN embeddings to be similar to the teacher EN embeddings
def prepare_dataset(batch):
    return {
        "english": batch["english"],
        "non_english": batch["non_english"],
        "label": teacher_model.encode(batch["english"], batch_size=inference_batch_size, show_progress_bar=False),
    }
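
# prepare_dataset produces rows that pair both text columns with the teacher embedding as the
# regression target, e.g. (illustrative values):
# {"english": "...", "non_english": "...", "label": [0.12, -0.34, ...]}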


column_names = list(train_dataset_dict.values())[0].column_names
train_dataset_dict = train_dataset_dict.map(
    prepare_dataset, batched=True, batch_size=30000, remove_columns=column_names
)
logging.info("Prepared datasets for training:", train_dataset_dict)

# 3. Define our training loss
# MSELoss (https://sbert.net/docs/package_reference/sentence_transformer/losses.html#mseloss) needs one or more
# text columns and one column with embeddings from the teacher model
train_loss = MSELoss(model=student_model)
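# Conceptually (a simplified sketch, ignoring batching and averaging), MSELoss pulls the student
# towards the teacher for every text column:
#   loss ≈ mse(student(english), label) + mse(student(non_english), label)
# where `label` is the precomputed teacher embedding of the English sentence.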

# 4. Define evaluators for use during training. This is useful to keep track of alongside the evaluation loss.
evaluators = []

for subset, eval_dataset in eval_dataset_dict.items():
    logger.info(f"Creating evaluators for {subset}")

    # Mean Squared Error (MSE) measures the (squared Euclidean) distance between teacher and student embeddings
    dev_mse = MSEEvaluator(
        source_sentences=eval_dataset["english"],
        target_sentences=eval_dataset["non_english"],
        name=subset,
        teacher_model=teacher_model,
        batch_size=inference_batch_size,
    )
    evaluators.append(dev_mse)

    # TranslationEvaluator computes the embeddings for all parallel sentences. It then checks whether the embedding of
    # source[i] is the closest to target[i] out of all available target sentences
    dev_trans_acc = TranslationEvaluator(
        source_sentences=eval_dataset["english"],
        target_sentences=eval_dataset["non_english"],
        name=subset,
        batch_size=inference_batch_size,
    )
    evaluators.append(dev_trans_acc)

    # Try to load this subset from STS17
    test_dataset = None
    try:
        test_dataset = load_dataset("mteb/sts17-crosslingual-sts", subset, split="test")
    except Exception:
        try:
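            # STS17 subset names are directional (e.g. "en-de" may exist while "de-en" does not),
            # so fall back to trying the reversed language pair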
            test_dataset = load_dataset("mteb/sts17-crosslingual-sts", f"{subset[3:]}-{subset[:2]}", split="test")
            subset = f"{subset[3:]}-{subset[:2]}"
        except Exception:
            pass
    if test_dataset:
        test_evaluator = EmbeddingSimilarityEvaluator(
            sentences1=test_dataset["sentence1"],
            sentences2=test_dataset["sentence2"],
            scores=[score / 5.0 for score in test_dataset["score"]],  # Convert 0-5 scores to 0-1 scores
            batch_size=inference_batch_size,
            name=f"sts17-{subset}-test",
            show_progress_bar=False,
        )
        evaluators.append(test_evaluator)

evaluator = SequentialEvaluator(evaluators, main_score_function=lambda scores: np.mean(scores))
# Now also prepare the evaluation datasets for training
eval_dataset_dict = eval_dataset_dict.map(prepare_dataset, batched=True, batch_size=30000, remove_columns=column_names)

# 5. Define the training arguments
args = SentenceTransformerTrainingArguments(
    # Required parameter:
    output_dir=output_dir,
    # Optional training parameters:
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,
    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
    bf16=False,  # Set to True if you have a GPU that supports BF16
    learning_rate=2e-5,
    # Optional tracking/debugging parameters:
    eval_strategy="steps",
    eval_steps=num_evaluation_steps,
    save_strategy="steps",
    save_steps=num_evaluation_steps,
    save_total_limit=2,
    logging_steps=100,
    run_name=f"multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}",  # Will be used in W&B if `wandb` is installed
)

# 6. Create the trainer & start training
trainer = SentenceTransformerTrainer(
    model=student_model,
    args=args,
    train_dataset=train_dataset_dict,
    eval_dataset=eval_dataset_dict,
    loss=train_loss,
    evaluator=evaluator,
)
trainer.train()

# 7. Save the trained & evaluated model locally
final_output_dir = f"{output_dir}/final"
student_model.save(final_output_dir)
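
# A minimal usage sketch once training has finished (example sentences are illustrative):
# loaded_model = SentenceTransformer(final_output_dir)
# embeddings = loaded_model.encode(["Hello world", "Hallo Welt", "Hola mundo"])
# print(embeddings.shape)  # (number of sentences, embedding dimension)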

# 8. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
model_name = student_model_name.split("/")[-1]  # e.g. "org/model" -> "model"
try:
    student_model.push_to_hub(f"{model_name}-multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}")
except Exception:
    logging.error(
        f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
        f"`huggingface-cli login`, followed by loading the model using `model = SentenceTransformer({final_output_dir!r})` "
        f"and saving it using `model.push_to_hub('{model_name}-multilingual-{'-'.join(source_languages)}-{'-'.join(target_languages)}')`."
    )