MSEEvaluatorFromDataFrame.py 5.35 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
import csv
Rayyyyy's avatar
Rayyyyy committed
2
3
import logging
import os
Rayyyyy's avatar
Rayyyyy committed
4
5
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
Rayyyyy's avatar
Rayyyyy committed
6

Rayyyyy's avatar
Rayyyyy committed
7
8
9
10
11
12
import numpy as np

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer
Rayyyyy's avatar
Rayyyyy committed
13
14
15
16
17
18
19
20

logger = logging.getLogger(__name__)


class MSEEvaluatorFromDataFrame(SentenceEvaluator):
    """
    Computes the mean squared error (x100) between the computed sentence embedding and some target sentence embedding.

    The teacher embeddings of all source sentences are computed once in the constructor and cached by
    sentence text. Each call then encodes the target sentences with the model under evaluation and
    compares them against the cached teacher embeddings.

    Args:
        dataframe (List[Dict[str, str]]): It must have the following format. Rows contains different, parallel sentences.
            Columns are the respective language codes::

            [{'en': 'My sentence in English', 'es': 'Oración en español', 'fr': 'Phrase en français'...},
             {'en': 'My second sentence', ...}]
        teacher_model (SentenceTransformer): The teacher model used to compute the sentence embeddings.
        combinations (List[Tuple[str, str]]): Must be of the format ``[('en', 'es'), ('en', 'fr'), ...]``.
            First entry in a tuple is the source language. The sentence in the respective language will be fetched from
            the dataframe and passed to the teacher model. Second entry in a tuple is the target language. Sentence
            will be fetched from the dataframe and passed to the student model
        batch_size (int, optional): The batch size to compute sentence embeddings. Defaults to 8.
        name (str, optional): The name of the evaluator. Defaults to "".
        write_csv (bool, optional): Whether to write the results to a CSV file. Defaults to True.
        truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to. If None, uses the model's
            current truncation dimension. Defaults to None.
    """

    def __init__(
        self,
        dataframe: List[Dict[str, str]],
        teacher_model: "SentenceTransformer",
        combinations: List[Tuple[str, str]],
        batch_size: int = 8,
        name: str = "",
        write_csv: bool = True,
        truncate_dim: Optional[int] = None,
    ):
        super().__init__()
        self.combinations = combinations
        self.name = name
        self.batch_size = batch_size

        # The CSV file name carries the evaluator name as an infix, if one was given.
        if name:
            name = "_" + name

        self.csv_file = "mse_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps"]
        self.primary_metric = "negative_mse"
        self.write_csv = write_csv
        self.truncate_dim = truncate_dim
        # Maps (src_lang, trg_lang) -> (source sentences, target sentences), index-aligned.
        self.data = {}

        logger.info("Compute teacher embeddings")
        # Insertion-ordered dict used as an ordered set: deduplicates source sentences while
        # keeping the order in which they are handed to the teacher reproducible across runs.
        unique_source_sentences = {}
        for src_lang, trg_lang in self.combinations:
            src_sentences = []
            trg_sentences = []

            for row in dataframe:
                src_text = row[src_lang]
                trg_text = row[trg_lang]
                # Only keep pairs where both sides contain non-whitespace text.
                if src_text.strip() and trg_text.strip():
                    unique_source_sentences[src_text] = None
                    src_sentences.append(src_text)
                    trg_sentences.append(trg_text)

            self.data[(src_lang, trg_lang)] = (src_sentences, trg_sentences)
            self.csv_headers.append("{}-{}".format(src_lang, trg_lang))

        all_source_sentences = list(unique_source_sentences)

        with self._truncation_context(teacher_model):
            all_src_embeddings = teacher_model.encode(all_source_sentences, batch_size=self.batch_size)

        # Cache of teacher embeddings keyed by the exact source sentence text.
        self.teacher_embeddings = dict(zip(all_source_sentences, all_src_embeddings))

    def _truncation_context(self, model: "SentenceTransformer"):
        """Return a context manager that truncates ``model``'s embeddings to ``self.truncate_dim``, or a no-op if unset."""
        if self.truncate_dim is None:
            return nullcontext()
        return model.truncate_sentence_embeddings(self.truncate_dim)

    def __call__(
        self, model: "SentenceTransformer", output_path: Optional[str] = None, epoch: int = -1, steps: int = -1
    ) -> Dict[str, float]:
        """
        Evaluate ``model`` against the cached teacher embeddings.

        Args:
            model (SentenceTransformer): The (student) model whose target-language embeddings are scored.
            output_path (Optional[str]): Directory to write the CSV results into. Nothing is written when
                this is None or ``write_csv`` is False. The directory is created if it does not exist.
            epoch (int): Value written to the ``epoch`` CSV column. Defaults to -1.
            steps (int): Value written to the ``steps`` CSV column. Defaults to -1.

        Returns:
            Dict[str, float]: The negated mean of the per-combination MSE (x100) scores under the
            ``negative_mse`` metric, after being passed through ``prefix_name_to_metrics``.
        """
        model.eval()

        mse_scores = []
        for src_lang, trg_lang in self.combinations:
            src_sentences, trg_sentences = self.data[(src_lang, trg_lang)]

            src_embeddings = np.asarray([self.teacher_embeddings[sent] for sent in src_sentences])
            with self._truncation_context(model):
                trg_embeddings = np.asarray(model.encode(trg_sentences, batch_size=self.batch_size))

            mse = ((src_embeddings - trg_embeddings) ** 2).mean()
            mse *= 100
            mse_scores.append(mse)

            logger.info("MSE evaluation on {} dataset - {}-{}:".format(self.name, src_lang, trg_lang))
            # ".4f" (four decimals); the previous "4f" only set a minimum field width of 4.
            logger.info("MSE (*100):\t{:.4f}".format(mse))

        if output_path is not None and self.write_csv:
            # Make sure the target directory exists so opening the CSV cannot fail on a fresh path.
            os.makedirs(output_path, exist_ok=True)
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
                writer = csv.writer(f)
                # The header row is written only when the file is created.
                if not output_file_exists:
                    writer.writerow(self.csv_headers)

                writer.writerow([epoch, steps] + mse_scores)

        # Return negative score as SentenceTransformers maximizes the performance
        metrics = {"negative_mse": -np.mean(mse_scores).item()}
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    @property
    def description(self) -> str:
        """Short human-readable label for this evaluator."""
        return "Knowledge Distillation"