import csv
import logging
import os
from collections import defaultdict, deque
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.util import paraphrase_mining

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

logger = logging.getLogger(__name__)


class ParaphraseMiningEvaluator(SentenceEvaluator):
    """
    Given a large set of sentences, this evaluator performs paraphrase (duplicate) mining and
    identifies the pairs with the highest similarity. It compares the extracted paraphrase pairs
    with a set of gold labels and computes the F1 score.

    Example:
        ::

            from datasets import load_dataset
            from sentence_transformers.SentenceTransformer import SentenceTransformer
            from sentence_transformers.evaluation import ParaphraseMiningEvaluator

            # Load a model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Load the Quora Duplicates Mining dataset
            questions_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "questions", split="dev")
            duplicates_dataset = load_dataset("sentence-transformers/quora-duplicates-mining", "duplicates", split="dev")

            # Create a mapping from qid to question & a list of duplicates (qid1, qid2)
            qid_to_questions = dict(zip(questions_dataset["qid"], questions_dataset["question"]))
            duplicates = list(zip(duplicates_dataset["qid1"], duplicates_dataset["qid2"]))

            # Initialize the paraphrase mining evaluator
            paraphrase_mining_evaluator = ParaphraseMiningEvaluator(
                sentences_map=qid_to_questions,
                duplicates_list=duplicates,
                name="quora-duplicates-dev",
            )
            results = paraphrase_mining_evaluator(model)
            '''
            Paraphrase Mining Evaluation of the model on the quora-duplicates-dev dataset:
            Number of candidate pairs: 250564
            Average Precision: 56.51
            Optimal threshold: 0.8325
            Precision: 52.76
            Recall: 59.19
            F1: 55.79
            '''
            print(paraphrase_mining_evaluator.primary_metric)
            # => "quora-duplicates-dev_average_precision"
            print(results[paraphrase_mining_evaluator.primary_metric])
            # => 0.5650940787776353
    """

    def __init__(
        self,
        sentences_map: Dict[str, str],
        duplicates_list: Optional[List[Tuple[str, str]]] = None,
        duplicates_dict: Optional[Dict[str, Dict[str, bool]]] = None,
        add_transitive_closure: bool = False,
        query_chunk_size: int = 5000,
        corpus_chunk_size: int = 100000,
        max_pairs: int = 500000,
        top_k: int = 100,
        show_progress_bar: bool = False,
        batch_size: int = 16,
        name: str = "",
        write_csv: bool = True,
        truncate_dim: Optional[int] = None,
    ):
        """
        Initializes the ParaphraseMiningEvaluator.

        Args:
            sentences_map (Dict[str, str]): A dictionary that maps sentence-ids to sentences.
                For example, sentences_map[id] => sentence.
            duplicates_list (List[Tuple[str, str]], optional): A list of id pairs [(id1, id2), (id1, id5)]
                that identifies the duplicates / paraphrases in the sentences_map. Defaults to None.
            duplicates_dict (Dict[str, Dict[str, bool]], optional): A nested dictionary mapping [id1][id2]
                to True if id1 and id2 are duplicates. Must be symmetric, i.e., if [id1][id2] => True,
                then [id2][id1] => True as well. Defaults to None. See the example below.
            add_transitive_closure (bool, optional): If True, the transitive closure of the duplicates
                is added, i.e. if dup[a][b] and dup[b][c], then dup[a][c]. Defaults to False.
            query_chunk_size (int, optional): To identify the paraphrases, the cosine similarity between
                all sentence pairs will be computed. As this might require a lot of memory, we perform
                a batched computation. query_chunk_size sentences will be compared against up to
                corpus_chunk_size sentences. In the default setting, 5000 sentences will be grouped
                together and compared against up to 100k other sentences at a time. Defaults to 5000.
            corpus_chunk_size (int, optional): The corpus will be batched, to reduce the memory requirement.
                Defaults to 100000.
            max_pairs (int, optional): We will only extract up to max_pairs potential paraphrase candidates.
                Defaults to 500000.
            top_k (int, optional): For each query, we extract the top_k most similar pairs and add them to a
                sorted list, i.e., for any one sentence no more than top_k paraphrases will be found. Defaults to 100.
            show_progress_bar (bool, optional): Output a progress bar. Defaults to False.
            batch_size (int, optional): Batch size for computing sentence embeddings. Defaults to 16.
            name (str, optional): Name of the experiment. Defaults to "".
            write_csv (bool, optional): Write results to CSV file. Defaults to True.
            truncate_dim (Optional[int], optional): The dimension to truncate sentence embeddings to.
                `None` uses the model's current truncation dimension. Defaults to None.
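
        Example:
            A minimal sketch of the two accepted duplicate formats; the ids below are
            hypothetical placeholders that must appear as keys of ``sentences_map``::

                duplicates_list = [("qid1", "qid2"), ("qid1", "qid5")]
                duplicates_dict = {"qid1": {"qid2": True}, "qid2": {"qid1": True}}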
        """
        super().__init__()
        self.sentences = []
        self.ids = []

        for sentence_id, sentence in sentences_map.items():
            self.sentences.append(sentence)
            self.ids.append(sentence_id)

        self.name = name
        self.show_progress_bar = show_progress_bar
        self.batch_size = batch_size
        self.query_chunk_size = query_chunk_size
        self.corpus_chunk_size = corpus_chunk_size
        self.max_pairs = max_pairs
        self.top_k = top_k
        self.truncate_dim = truncate_dim

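        # Nested mapping duplicates[id1][id2] -> bool; the defaultdict returns False for unseen pairs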
        self.duplicates = duplicates_dict if duplicates_dict is not None else defaultdict(lambda: defaultdict(bool))
        if duplicates_list is not None:
            for id1, id2 in duplicates_list:
                if id1 in sentences_map and id2 in sentences_map:
                    self.duplicates[id1][id2] = True
                    self.duplicates[id2][id1] = True

        # Add transitive closure
        if add_transitive_closure:
            self.duplicates = self.add_transitive_closure(self.duplicates)

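        # Count every gold duplicate pair exactly once, regardless of the direction it was stored in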
        positive_key_pairs = set()
        for key1 in self.duplicates:
            for key2 in self.duplicates[key1]:
                if (
                    key1 in sentences_map
                    and key2 in sentences_map
                    and (self.duplicates[key1][key2] or self.duplicates[key2][key1])
                ):
                    positive_key_pairs.add(tuple(sorted([key1, key2])))

        self.total_num_duplicates = len(positive_key_pairs)

        if name:
            name = "_" + name

        self.csv_file: str = "paraphrase_mining_evaluation" + name + "_results.csv"
        self.csv_headers = ["epoch", "steps", "precision", "recall", "f1", "threshold", "average_precision"]
        self.write_csv = write_csv
        self.primary_metric = "average_precision"

    def __call__(
        self, model: "SentenceTransformer", output_path: Optional[str] = None, epoch: int = -1, steps: int = -1
    ) -> Dict[str, float]:
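        """
        Mine paraphrase candidates with the given model and score them against the gold
        duplicates. Returns a dict with average precision (the primary metric), F1,
        precision, recall, and the optimal threshold, prefixed with the evaluator name
        when one was set.
        """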
        if epoch != -1:
            if steps == -1:
                out_txt = f" after epoch {epoch}"
            else:
                out_txt = f" in epoch {epoch} after {steps} steps"
        else:
            out_txt = ""
        if self.truncate_dim is not None:
            out_txt += f" (truncated to {self.truncate_dim})"

        logger.info(f"Paraphrase Mining Evaluation of the model on the {self.name} dataset{out_txt}:")

        # Compute embedding for the sentences
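        # If truncate_dim is set, embeddings are computed (and mined) at that truncated dimensionality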
        with nullcontext() if self.truncate_dim is None else model.truncate_sentence_embeddings(self.truncate_dim):
            pairs_list = paraphrase_mining(
                model,
                self.sentences,
                self.show_progress_bar,
                self.batch_size,
                self.query_chunk_size,
                self.corpus_chunk_size,
                self.max_pairs,
                self.top_k,
            )

        logger.info("Number of candidate pairs: " + str(len(pairs_list)))

        # Compute F1 score and Average Precision
        n_extract = n_correct = 0
        threshold = 0
        best_f1 = best_recall = best_precision = 0

        average_precision = 0

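        # pairs_list is sorted by decreasing similarity, so walking down it simulates every
        # possible decision threshold. At each true duplicate we record the running precision;
        # summing these values and dividing by the number of gold pairs yields the average precision.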
        for idx, (score, i, j) in enumerate(pairs_list):
            id1 = self.ids[i]
            id2 = self.ids[j]

            # Compute optimal threshold and F1-score
            n_extract += 1
            if self.duplicates[id1][id2] or self.duplicates[id2][id1]:
                n_correct += 1
                precision = n_correct / n_extract
                recall = n_correct / self.total_num_duplicates
                f1 = 2 * precision * recall / (precision + recall)
                average_precision += precision
                if f1 > best_f1:
                    best_f1 = f1
                    best_precision = precision
                    best_recall = recall
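                    # Place the threshold halfway between this pair's score and the next one's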
                    threshold = (score + pairs_list[min(idx + 1, len(pairs_list) - 1)][0]) / 2

        average_precision = average_precision / self.total_num_duplicates

        logger.info("Average Precision: {:.2f}".format(average_precision * 100))
        logger.info("Optimal threshold: {:.4f}".format(threshold))
        logger.info("Precision: {:.2f}".format(best_precision * 100))
        logger.info("Recall: {:.2f}".format(best_recall * 100))
        logger.info("F1: {:.2f}\n".format(best_f1 * 100))

        if output_path is not None and self.write_csv:
            csv_path = os.path.join(output_path, self.csv_file)
            output_file_exists = os.path.isfile(csv_path)
            with open(csv_path, newline="", mode="a" if output_file_exists else "w", encoding="utf-8") as f:
                writer = csv.writer(f)
                if not output_file_exists:
                    writer.writerow(self.csv_headers)
                writer.writerow([epoch, steps, best_precision, best_recall, best_f1, threshold, average_precision])

        metrics = {
            "average_precision": average_precision,
            "f1": best_f1,
            "precision": best_precision,
            "recall": best_recall,
            "threshold": threshold,
        }
        metrics = self.prefix_name_to_metrics(metrics, self.name)
        self.store_metrics_in_model_card_data(model, metrics)
        return metrics

    @staticmethod
    def add_transitive_closure(graph):
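        """Make indirect duplicate relations explicit: if dup[a][b] and dup[b][c],
        then dup[a][c] is set as well, by marking every pair of nodes within a
        connected component as duplicates of each other."""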
        nodes_visited = set()
        for a in list(graph.keys()):
            if a not in nodes_visited:
                connected_subgraph_nodes = {a}

                # Collect all nodes of the connected component via breadth-first search
                neighbor_nodes_queue = deque(graph[a])
                while neighbor_nodes_queue:
                    node = neighbor_nodes_queue.popleft()
                    if node not in connected_subgraph_nodes:
                        connected_subgraph_nodes.add(node)
                        neighbor_nodes_queue.extend(graph[node])

                # Ensure transitivity between all nodes in the graph
                connected_subgraph_nodes = list(connected_subgraph_nodes)
                for i in range(len(connected_subgraph_nodes) - 1):
                    for j in range(i + 1, len(connected_subgraph_nodes)):
                        graph[connected_subgraph_nodes[i]][connected_subgraph_nodes[j]] = True
                        graph[connected_subgraph_nodes[j]][connected_subgraph_nodes[i]] = True

                        nodes_visited.add(connected_subgraph_nodes[i])
                        nodes_visited.add(connected_subgraph_nodes[j])
        return graph