evaluation_translation_matching.py 2.1 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
"""
Given a tab separated file (.tsv) with parallel sentences, where the second column is the translation of the sentence in the first column, for example, in the format:
src1    trg1
src2    trg2
...

where trg_i is the translation of src_i.

Given src_i, the TranslationEvaluator checks which trg_j has the highest similarity using cosine similarity. If i == j, we assume
a match, i.e., the correct translation has been found for src_i out of all possible target sentences.

It then computes an accuracy over all possible source sentences src_i. Likewise, it also computes the accuracy for the other direction (target to source).

A high accuracy score indicates that the model is able to find the correct translation out of a large pool of sentences.

Usage:
python evaluation_translation_matching.py [model_name_or_path] [parallel-file1] [parallel-file2] ...

For example:
python evaluation_translation_matching.py distiluse-base-multilingual-cased talks-en-de.tsv.gz

See the training_multilingual/get_parallel_data_...py scripts for getting parallel sentence data from different sources
"""

from sentence_transformers import SentenceTransformer, evaluation, LoggingHandler
import sys
import gzip
import os
import logging


# Configure root logging: timestamped INFO-level messages, emitted through the
# LoggingHandler imported from sentence_transformers as the only handler.
logging.basicConfig(
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
    handlers=[LoggingHandler()],
)

logger = logging.getLogger(__name__)

# Without this guard a missing model argument surfaces as a bare IndexError
# traceback; exit with a usage hint instead (sys.exit prints it to stderr).
if len(sys.argv) < 2:
    sys.exit(
        "Usage: python "
        + os.path.basename(sys.argv[0])
        + " [model_name_or_path] [parallel-file1] [parallel-file2] ..."
    )

model_name = sys.argv[1]  # first CLI argument, handed to SentenceTransformer below
filepaths = sys.argv[2:]  # remaining CLI arguments: the parallel-sentence files to evaluate
inference_batch_size = 32  # batch size forwarded to the evaluator

model = SentenceTransformer(model_name)


# Evaluate the model on each parallel-sentence file given on the command line.
for filepath in filepaths:
    src_sentences = []
    trg_sentences = []

    # Paths ending in ".gz" are read through gzip, everything else as plain
    # text; both are decoded as UTF-8.
    open_fn = gzip.open if filepath.endswith(".gz") else open
    with open_fn(filepath, "rt", encoding="utf8") as fIn:
        for line in fIn:
            # Remove only the line terminator before splitting: a full strip()
            # would also swallow a leading tab (empty source column) and shift
            # the remaining columns, producing a wrong src/trg pair.
            columns = line.rstrip("\r\n").split("\t")
            if len(columns) < 2:
                continue
            src, trg = columns[0].strip(), columns[1].strip()
            # A row with an empty side cannot form a translation pair.
            if src and trg:
                src_sentences.append(src)
                trg_sentences.append(trg)

    file_name = os.path.basename(filepath)
    logger.info("%s: %d sentence pairs", file_name, len(src_sentences))

    # An accuracy over zero pairs is undefined, so do not hand an empty
    # dataset to the evaluator; report it and move on to the next file.
    if not src_sentences:
        logger.warning("%s: no sentence pairs found, skipping evaluation", file_name)
        continue

    dev_trans_acc = evaluation.TranslationEvaluator(
        src_sentences, trg_sentences, name=file_name, batch_size=inference_batch_size
    )
    dev_trans_acc(model)