matryoshka_eval_stsb.py 5.94 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
This script evaluates embedding models truncated at different dimensions on the STS
benchmark.
"""

import argparse
import os
from typing import Dict, List, Optional, Tuple, cast

from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import (
    EmbeddingSimilarityEvaluator,
    SimilarityFunction,
)
from tqdm.auto import tqdm


# Dimension plot
def _grouped_barplot_ratios(
    group_name_to_x_to_y: Dict[str, Dict[int, float]], ax: Optional[plt.Axes] = None
) -> plt.Axes:
    """Draw a grouped bar plot of each group's y-values as a fraction of that
    group's maximum y.

    :param group_name_to_x_to_y: mapping of group name -> {x: y}. All groups
        must share the same set of x keys.
    :param ax: axes to draw on; a new figure/axes pair is created when omitted.
    :raises ValueError: if the groups do not share identical x keys.
    :return: the axes the bars were drawn on.
    """
    # To save a pandas dependency, do from scratch in matplotlib
    if ax is None:
        # BUG FIX: plt.subplots() returns a (Figure, Axes) tuple; the original
        # assigned the tuple itself to `ax`, breaking every ax.* call below.
        _, ax = plt.subplots()
    # Sort each by x
    group_name_to_x_to_y = {
        group_name: dict(sorted(x_to_y.items(), key=lambda x: x[0]))
        for group_name, x_to_y in group_name_to_x_to_y.items()
    }
    # Check that all x are the same (dict_keys compare as sets, order-insensitive)
    xticks = None
    for group_name, x_to_y in group_name_to_x_to_y.items():
        _xticks = x_to_y.keys()
        if xticks is not None and _xticks != xticks:
            raise ValueError(f"{group_name} has different keys: {_xticks}")
        xticks = _xticks
    xticks = sorted(xticks)

    # Max y will be the denominator in the ratio/fraction
    group_name_to_max_y = {
        group_name: max(x_to_y.values())
        for group_name, x_to_y in group_name_to_x_to_y.items()
    }
    num_groups = len(group_name_to_x_to_y)
    bar_width = np.diff(xticks).min() / (num_groups + 1)
    # bar_width is the solution to this equation:
    # Say we have the closest x1, x2 st x1 < x2, so x2 - x1 = np.diff(xticks).min().
    # (x2 - (bar_width * num_groups/2)) - (x1 + (bar_width * num_groups/2)) = bar_width
    xs = np.array(
        [
            np.linspace(
                start=xtick - ((bar_width / 2) * (num_groups - 1)),
                stop=xtick + ((bar_width / 2) * (num_groups - 1)),
                num=num_groups,
            )
            for xtick in xticks
        ]
    ).T
    # xs are the center of where the bar goes on the x axis. They have to be manually set
    min_ratio = np.inf
    for i, (group_name, x_to_y) in enumerate(group_name_to_x_to_y.items()):
        max_y = group_name_to_max_y[group_name]
        ys = [y / max_y for y in x_to_y.values()]
        min_ratio = min(min_ratio, min(ys))
        ax.bar(xs[i], ys, bar_width, label=group_name)
    ax.set_xticks(xticks)
    ax.set_xticklabels(xticks)
    ax.grid(linestyle="--")
    # Zoom the y axis in on the interesting range just below 1.0
    ax.set_ylim(min(0.95, min_ratio), 1)
    return ax


def plot_across_dimensions(
    model_name_to_dim_to_score: Dict[str, Dict[int, float]],
    filename: str,
    figsize: Tuple[float, float] = (7, 7),
    title: str = "STSB test score for various embedding dimensions (via truncation),\nwith and without Matryoshka loss",
) -> None:
    """Save a two-panel figure comparing models across embedding dimensions.

    The top panel is a line plot of raw scores per dimension; the bottom panel
    shows each model's score as a ratio of its own best score.

    :param model_name_to_dim_to_score: mapping of model name -> {dim: score}.
    :param filename: path the figure is written to.
    :param figsize: overall figure size in inches.
    :param title: figure suptitle.
    """
    # Re-key every inner dict in ascending dimension order
    model_name_to_dim_to_score = {
        model_name: {dim: dim_to_score[dim] for dim in sorted(dim_to_score)}
        for model_name, dim_to_score in model_name_to_dim_to_score.items()
    }
    first_scores = next(iter(model_name_to_dim_to_score.values()))
    xticks = sorted(first_scores.keys())

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)
    ax1 = cast(plt.Axes, ax1)
    ax2 = cast(plt.Axes, ax2)

    # Top panel: raw score per dimension, one line per model
    for model_name, dim_to_score in model_name_to_dim_to_score.items():
        ax1.plot(dim_to_score.keys(), dim_to_score.values(), label=model_name)
    ax1.set_xticks(xticks)
    ax1.set_ylabel("Spearman correlation")
    ax1.grid(linestyle="--")
    ax1.legend()

    # Bottom panel: score relative to each model's own maximum
    ax2 = _grouped_barplot_ratios(model_name_to_dim_to_score, ax=ax2)
    ax2.set_xlabel("Embedding dimension")
    ax2.set_ylabel("Ratio of maximum performance")

    fig.suptitle(title)
    fig.tight_layout()
    fig.savefig(filename)


if __name__ == "__main__":
    DEFAULT_MODEL_NAMES = [
        "tomaarsen/mpnet-base-nli-matryoshka",  # fit using Matryoshka loss
        "tomaarsen/mpnet-base-nli",  # baseline
    ]
    DEFAULT_DIMENSIONS = [768, 512, 256, 128, 64]

    # Parse args
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("plot_filename", type=str, help="Where to save the plot of results")
    parser.add_argument(
        "--model_names",
        nargs="+",
        default=DEFAULT_MODEL_NAMES,
        help=(
            "List of models which can be loaded using "
            "sentence_transformers.SentenceTransformer(). Default: "
            f"{' '.join(DEFAULT_MODEL_NAMES)}"
        ),
    )
    parser.add_argument(
        "--dimensions",
        nargs="+",
        type=int,
        default=DEFAULT_DIMENSIONS,
        help=(
            "List of dimensions to truncate to and evaluate. Default: "
            f"{' '.join(str(dim) for dim in DEFAULT_DIMENSIONS)}"
        ),
    )

    args = parser.parse_args()
    plot_filename: str = args.plot_filename
    model_names: List[str] = args.model_names
    DIMENSIONS: List[int] = args.dimensions

    # Load STSb
    stsb_test = load_dataset("mteb/stsbenchmark-sts", split="test")
    test_evaluator = EmbeddingSimilarityEvaluator(
        stsb_test["sentence1"],
        stsb_test["sentence2"],
        [score / 5 for score in stsb_test["score"]],
        main_similarity=SimilarityFunction.COSINE,
        name="sts-test",
    )

    # Run test_evaluator
    model_name_to_dim_to_score: Dict[str, Dict[int, float]] = {}
    for model_name in tqdm(model_names, desc="Evaluating models"):
        model = SentenceTransformer(model_name)
        dim_to_score: Dict[int, float] = {}
        for dim in tqdm(DIMENSIONS, desc=f"Evaluating {model_name}"):
            output_path = os.path.join(model_name, f"dim-{dim}")
            os.makedirs(output_path)
            with model.truncate_sentence_embeddings(dim):
                score = test_evaluator(model, output_path=output_path)
            print(f"Saved results to {output_path}")
            dim_to_score[dim] = score
        model_name_to_dim_to_score[model_name] = dim_to_score

    # Save plot
    plot_across_dimensions(model_name_to_dim_to_score, plot_filename)