conftest.py 1.94 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
import os
import platform
import tempfile
Rayyyyy's avatar
Rayyyyy committed
4

Rayyyyy's avatar
Rayyyyy committed
5
6
import pytest

Rayyyyy's avatar
Rayyyyy committed
7
8
9
10
11
12
from sentence_transformers import CrossEncoder, SentenceTransformer
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers.util import is_datasets_available

if is_datasets_available():
    from datasets import DatasetDict, load_dataset
Rayyyyy's avatar
Rayyyyy committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


@pytest.fixture()
def stsb_bert_tiny_model() -> SentenceTransformer:
    """Load the tiny STSb BERT test model.

    The fixture uses pytest's default (function) scope, so every test that
    requests it gets its own freshly loaded instance and may modify it freely.
    """
    model_id = "sentence-transformers-testing/stsb-bert-tiny-safetensors"
    model = SentenceTransformer(model_id)
    return model


@pytest.fixture(scope="session")
def stsb_bert_tiny_model_reused() -> SentenceTransformer:
    """Load the tiny STSb BERT test model once per test session.

    Because the fixture is session-scoped, the same instance is handed to every
    test that requests it; a test that mutates it would affect later tests.
    """
    model_id = "sentence-transformers-testing/stsb-bert-tiny-safetensors"
    shared_model = SentenceTransformer(model_id)
    return shared_model


@pytest.fixture()
def paraphrase_distilroberta_base_v1_model() -> SentenceTransformer:
    """Load the ``paraphrase-distilroberta-base-v1`` model, freshly for each test."""
    model_id = "paraphrase-distilroberta-base-v1"
    model = SentenceTransformer(model_id)
    return model


@pytest.fixture()
def distilroberta_base_ce_model() -> CrossEncoder:
    """Build a ``CrossEncoder`` on top of ``distilroberta-base`` with one output label.

    A new instance is created for each test that requests the fixture.
    """
    base_model_id = "distilroberta-base"
    cross_encoder = CrossEncoder(base_model_id, num_labels=1)
    return cross_encoder


@pytest.fixture()
def clip_vit_b_32_model() -> SentenceTransformer:
    """Load the ``clip-ViT-B-32`` model, freshly for each test that requests it."""
    model_id = "clip-ViT-B-32"
    model = SentenceTransformer(model_id)
    return model


@pytest.fixture()
def distilbert_base_uncased_model() -> SentenceTransformer:
    """Assemble a ``SentenceTransformer`` from raw ``distilbert-base-uncased``.

    The model is built from two modules: the ``Transformer`` encoder, followed
    by a ``Pooling`` layer sized to the encoder's word-embedding dimension.
    """
    encoder = Transformer("distilbert-base-uncased")
    # The pooling layer must match the encoder's output embedding width.
    return SentenceTransformer(
        modules=[encoder, Pooling(encoder.get_word_embedding_dimension())]
    )


Rayyyyy's avatar
Rayyyyy committed
48
49
50
51
52
@pytest.fixture(scope="session")
def stsb_dataset_dict() -> "DatasetDict":
    """Load the ``mteb/stsbenchmark-sts`` dataset once per test session.

    ``load_dataset`` is only imported at module level when
    ``is_datasets_available()`` is true. Without this guard, a missing
    ``datasets`` package would surface as a confusing ``NameError`` during
    fixture setup. Instead, any test requesting this fixture is skipped with
    an explicit reason.

    Returns:
        The ``DatasetDict`` returned by ``load_dataset``.
    """
    if not is_datasets_available():
        pytest.skip("The `datasets` package is required for the `stsb_dataset_dict` fixture.")
    return load_dataset("mteb/stsbenchmark-sts")


Rayyyyy's avatar
Rayyyyy committed
53
54
55
56
57
58
59
60
61
62
63
64
65
@pytest.fixture()
def cache_dir():
    """
    Yield the ``cache_dir`` that tests should pass when loading models.

    When running on Linux in CI (the ``CI`` environment variable is set to a
    non-empty value), a fresh temporary directory is yielded. It is removed once
    the test finishes, so downloaded models are not kept on disk. This is only
    needed on Ubuntu, where disk space is otherwise an issue.

    In every other environment ``None`` is yielded. Presumably consumers then
    fall back to their default cache location (not verifiable from here).
    """
    linux_ci_run = bool(os.environ.get("CI")) and platform.system() == "Linux"
    if not linux_ci_run:
        yield None
        return
    # The context manager deletes the directory after the test's teardown resumes us.
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir