"""
Given a dataset of parallel sentences with an "english" column and a "non_english" column, this script evaluates a model on the translation matching task.
For each sentence in the "english" column, the model should find its correct translation in the "non_english" column based on the embeddings alone.

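For every source sentence src_i, the match counts as correct if the "non_english" sentence whose embedding is most similar to that of src_i (by cosine similarity) is its actual translation. The accuracy is the fraction of source sentences matched correctly. The same accuracy is also computed for the other direction ("non_english" to "english").
A high accuracy score indicates that the model can find the correct translation out of a large pool of sentences.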
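
As an illustration, the matching accuracy can be sketched in NumPy as follows. This is a minimal sketch, not the TranslationEvaluator implementation: it assumes the embeddings are already computed as two arrays whose i-th rows form a translation pair, and translation_accuracy is a hypothetical helper name.

```python
from typing import Tuple

import numpy as np


def translation_accuracy(src_emb: np.ndarray, trg_emb: np.ndarray) -> Tuple[float, float]:
    # Hypothetical helper: row i of src_emb and row i of trg_emb embed a translation pair.
    # L2-normalise the rows so that dot products equal cosine similarities.
    src = src_emb / np.linalg.norm(src_emb, axis=1, keepdims=True)
    trg = trg_emb / np.linalg.norm(trg_emb, axis=1, keepdims=True)
    sims = src @ trg.T  # sims[i, j] = cos(src_i, trg_j)

    gold = np.arange(len(src))
    # Source -> target: is the most similar target for src_i the paired trg_i?
    src2trg = float(np.mean(sims.argmax(axis=1) == gold))
    # Target -> source: the same check column-wise, for the other direction.
    trg2src = float(np.mean(sims.argmax(axis=0) == gold))
    return src2trg, trg2src


src = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
trg = np.array([[0.9, 0.1], [0.1, 0.9], [0.6, 0.5]])
print(translation_accuracy(src, trg))  # (1.0, 1.0)
```

Taking the argmax along rows gives the source-to-target accuracy and along columns the reverse, so both directions come from a single similarity matrix.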

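Good options for datasets, as these have development sets, are:
* sentence-transformers/parallel-sentences-wikimatrix
* sentence-transformers/parallel-sentences-tatoeba
* sentence-transformers/parallel-sentences-talks

For a subset that only has a "train" split, the script instead evaluates on its first 5000 samples (or on all samples, if there are fewer).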

Usage:
python examples/evaluation/evaluation_translation_matching.py [model_name_or_path] [dataset_name] [subset1] [subset2] ...

For example:
python examples/evaluation/evaluation_translation_matching.py distiluse-base-multilingual-cased sentence-transformers/parallel-sentences-tatoeba en-ar en-de en-nl
"""

import logging
import sys

from datasets import load_dataset
from sentence_transformers import SentenceTransformer, evaluation

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

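# Expect a model, a dataset, and at least one subset on the command line
if len(sys.argv) < 4:
    sys.exit(f"Usage: python {sys.argv[0]} [model_name_or_path] [dataset_name] [subset1] [subset2] ...")

model_name = sys.argv[1]
dataset_name = sys.argv[2]
subsets = sys.argv[3:]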
inference_batch_size = 32

model = SentenceTransformer(model_name)

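for subset in subsets:
    dataset = load_dataset(dataset_name, subset)
    datasets = {}
    # DatasetDict.column_names maps split -> columns, so compare the split names instead
    if list(dataset.keys()) == ["train"]:
        # Only a training split exists: evaluate on (at most) its first 5000 samples
        num_samples = min(5000, len(dataset["train"]))
        datasets[f"train[:{num_samples}]"] = dataset["train"].select(range(num_samples))
    else:
        # Otherwise evaluate on every non-training split (e.g. dev/test)
        for split, sub_dataset in dataset.items():
            if split != "train":
                datasets[split] = sub_dataset

    for split, sub_dataset in datasets.items():
        logging.info(f"{dataset_name}, subset={subset}, split={split}, num_samples={len(sub_dataset)}")
        # Embed both columns and measure, in both directions, how often each sentence's
        # most similar counterpart is its actual translation
        translation_evaluator = evaluation.TranslationEvaluator(
            sub_dataset["english"],
            sub_dataset["non_english"],
            name=f"{dataset_name}-{subset}-{split}",
            batch_size=inference_batch_size,
        )
        translation_evaluator(model)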