evaluation_stsbenchmark.py 1.69 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
4
5
6
7
8
9
10
"""
This example loads a pre-trained model and evaluates it on the STSbenchmark dataset.

Usage:
python evaluation_stsbenchmark.py
OR
python evaluation_stsbenchmark.py model_name
"""

import logging
Rayyyyy's avatar
Rayyyyy committed
11
import os
Rayyyyy's avatar
Rayyyyy committed
12
import sys
Rayyyyy's avatar
Rayyyyy committed
13

Rayyyyy's avatar
Rayyyyy committed
14
import torch
Rayyyyy's avatar
Rayyyyy committed
15
16
17
18
19

from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction
Rayyyyy's avatar
Rayyyyy committed
20
21
22
23
24
25

# Cap torch at 4 threads.
torch.set_num_threads(4)

# Directory that contains this script, computed from its real (symlink-resolved) path.
script_folder_path = os.path.dirname(os.path.realpath(__file__))

Rayyyyy's avatar
Rayyyyy committed
26
27
# Configure logging at INFO level so more information is reported;
# each record is rendered as "<timestamp> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
Rayyyyy's avatar
Rayyyyy committed
28
29
30
31
32
33
34

# Pick the model to evaluate: the first command-line argument when one is
# given, otherwise the default checkpoint name.
if len(sys.argv) > 1:
    model_name = sys.argv[1]
else:
    model_name = "stsb-distilroberta-base-v2"

# Instantiate the sentence model by name (a named model is downloaded on demand).
# A filepath to a saved model can be passed to SentenceTransformer() instead.
model = SentenceTransformer(model_name)

Rayyyyy's avatar
Rayyyyy committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Evaluate on the "validation" split of the STS benchmark, scoring with
# cosine similarity.
stsb_eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
dev_evaluator = EmbeddingSimilarityEvaluator(
    name="sts-dev",
    main_similarity=SimilarityFunction.COSINE,
    sentences1=stsb_eval_dataset["sentence1"],
    sentences2=stsb_eval_dataset["sentence2"],
    scores=stsb_eval_dataset["score"],
)
model.evaluate(dev_evaluator)

# Evaluate on the "test" split of the STS benchmark, scoring with
# cosine similarity.
test_dataset = load_dataset("sentence-transformers/stsb", split="test")
test_evaluator = EmbeddingSimilarityEvaluator(
    name="sts-test",
    main_similarity=SimilarityFunction.COSINE,
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
)
model.evaluate(test_evaluator)