import csv
import logging
import os

from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional

import numpy as np
import torch

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.util import pytorch_cos_sim

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class TranslationEvaluator(SentenceEvaluator):
    """
    Given two sets of sentences in different languages, e.g. (en_1, en_2, en_3...) and (fr_1, fr_2, fr_3, ...),
    and assuming that fr_i is the translation of en_i.
    Checks if vec(en_i) has the highest similarity to vec(fr_i). Computes the accuracy in both directions.

    Example:
        ::

            from sentence_transformers import SentenceTransformer
            from sentence_transformers.evaluation import TranslationEvaluator
            from datasets import load_dataset

            # Load a model
            model = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2')

            # Load a parallel sentences dataset
            dataset = load_dataset("sentence-transformers/parallel-sentences-news-commentary", "en-nl", split="train[:1000]")

            # Initialize the TranslationEvaluator using the same texts from two languages
            translation_evaluator = TranslationEvaluator(
                source_sentences=dataset["english"],
                target_sentences=dataset["non_english"],
                name="news-commentary-en-nl",
            )
            results = translation_evaluator(model)
            '''
            Evaluating translation matching Accuracy of the model on the news-commentary-en-nl dataset:
            Accuracy src2trg: 90.80
            Accuracy trg2src: 90.40
            '''
            print(translation_evaluator.primary_metric)
            # => "news-commentary-en-nl_mean_accuracy"
            print(results[translation_evaluator.primary_metric])
            # => 0.906
    """

    def __init__(
        self,
        source_sentences: List[str],
        target_sentences: List[str],
        show_progress_bar: bool = False,
        batch_size: int = 16,
        name: str = "",
        print_wrong_matches: bool = False,
        write_csv: bool = True,
        truncate_dim: Optional[int] = None,
    ):
        """
        Constructs an evaluator for a parallel (translated) sentence dataset.

        Args:
            source_sentences (List[str]): List of sentences in the source language.
            target_sentences (List[str]): List of sentences in the target language;
                ``target_sentences[i]`` must be the translation of ``source_sentences[i]``.
            show_progress_bar (bool): Whether to show a progress bar when computing embeddings. Defaults to False.
            batch_size (int): The batch size to compute sentence embeddings. Defaults to 16.
            name (str): The name of the evaluator. Defaults to an empty string.
            print_wrong_matches (bool): Whether to print incorrect matches. Defaults to False.
            write_csv (bool): Whether to write the evaluation results to a CSV file. Defaults to True.
            truncate_dim (int, optional): The dimension to truncate sentence embeddings to. If None, the model's
                current truncation dimension will be used. Defaults to None.

        Raises:
            ValueError: If the two sentence lists differ in length.
        """
        super().__init__()
        self.source_sentences = source_sentences
        self.target_sentences = target_sentences
        self.name = name
        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.print_wrong_matches = print_wrong_matches
        self.truncate_dim = truncate_dim

        # `assert` is stripped when Python runs with -O, so validate explicitly:
        # the evaluator is meaningless unless the lists are aligned pairwise.
        if len(self.source_sentences) != len(self.target_sentences):
            raise ValueError(
                f"source_sentences ({len(self.source_sentences)}) and target_sentences "
                f"({len(self.target_sentences)}) must have the same length"
            )

        if name:
            name = "_" + name

        self.csv_file = "translation_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "src2trg", "trg2src"]
        self.write_csv = write_csv
        self.primary_metric = "mean_accuracy"

    def __call__(
        self, model: "SentenceTransformer", output_path: Optional[str] = None, epoch: int = -1, steps: int = -1
    ) -> Dict[str, float]:
        """
        Computes the translation-matching accuracy of the model in both directions.

        Args:
            model (SentenceTransformer): The model used to embed both sentence lists.
            output_path (str, optional): Directory for the CSV results file; nothing is written when None.
            epoch (int): Current training epoch, used for logging and the CSV row. -1 means "outside training".
            steps (int): Current step within the epoch. -1 means "at the end of the epoch".

        Returns:
            Dict[str, float]: src2trg accuracy, trg2src accuracy and their mean,
            with keys prefixed by the evaluator name.
        """
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        # Lazy %-style args avoid string formatting when INFO logging is disabled.
        logger.info("Evaluating translation matching Accuracy of the model on the %s dataset%s:", self.name, out_txt)

        # Optionally truncate embeddings to `truncate_dim` for the duration of the encode calls.
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            embeddings1 = torch.stack(
                model.encode(
                    self.source_sentences,
                    show_progress_bar=self.show_progress_bar,
                    batch_size=self.batch_size,
                    convert_to_numpy=False,
                )
            )
            embeddings2 = torch.stack(
                model.encode(
                    self.target_sentences,
                    show_progress_bar=self.show_progress_bar,
                    batch_size=self.batch_size,
                    convert_to_numpy=False,
                )
            )

        # (n, n) similarity matrix: row i holds similarities of source i against every target.
        cos_sims = pytorch_cos_sim(embeddings1, embeddings2).detach().cpu().numpy()

        correct_src2trg = 0
        correct_trg2src = 0

        for i in range(len(cos_sims)):
            max_idx = np.argmax(cos_sims[i])

            if i == max_idx:
                correct_src2trg += 1
            elif self.print_wrong_matches:
                # This branch only runs when i != max_idx, so the match is always incorrect.
                print("i:", i, "j:", max_idx, "INCORRECT")
                print("Src:", self.source_sentences[i])
                print("Trg:", self.target_sentences[max_idx])
                print("Argmax score:", cos_sims[i][max_idx], "vs. correct score:", cos_sims[i][i])

                results = zip(range(len(cos_sims[i])), cos_sims[i])
                results = sorted(results, key=lambda x: x[1], reverse=True)
                for idx, score in results[0:5]:
                    print("\t", idx, "(Score: %.4f)" % (score), self.target_sentences[idx])

        # Transpose to score the reverse direction (target -> source).
        cos_sims = cos_sims.T
        for i in range(len(cos_sims)):
            max_idx = np.argmax(cos_sims[i])
            if i == max_idx:
                correct_trg2src += 1

        acc_src2trg = correct_src2trg / len(cos_sims)
        acc_trg2src = correct_trg2src / len(cos_sims)

        logger.info("Accuracy src2trg: %.2f", acc_src2trg * 100)
        logger.info("Accuracy trg2src: %.2f", acc_trg2src * 100)

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            # Append to an existing results file; write the header only on first creation.
            with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps, acc_src2trg, acc_trg2src])

        metrics = {
            "src2trg_accuracy": acc_src2trg,
            "trg2src_accuracy": acc_trg2src,
            "mean_accuracy": (acc_src2trg + acc_trg2src) / 2,
        }
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics