# test_jina.py (4.16 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
from functools import partial

import pytest

import vllm.envs as envs
from vllm import PoolingParams

from ...utils import EmbedModelInfo, RerankModelInfo
from .embed_utils import (check_embeddings_close,
                          correctness_test_embed_models, matryoshka_fy)
from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models
14

15
# Embedding models parametrized into the embedding tests below.
# jina-embeddings-v3 is flagged is_matryoshka=True, so test_matryoshka
# exercises it instead of skipping.
EMBEDDING_MODELS = [
    EmbedModelInfo("jinaai/jina-embeddings-v3",
                   architecture="XLMRobertaModel",
                   is_matryoshka=True),
]

21
# Reranker models parametrized into test_rerank_models_mteb.
RERANK_MODELS = [
    RerankModelInfo("jinaai/jina-reranker-v2-base-multilingual",
                    architecture="XLMRobertaForSequenceClassification"),
]
25
26


27
28
29
30
31
32
33
34
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse shim that runs every test in this module on both engines.

    All the work happens in the ``run_with_both_engines`` fixture that this
    one requests; the shim itself does nothing. It could be moved up into
    ``conftest.py`` to apply the same behavior to a whole test package.
    """


35
36
37
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    """Run the shared MTEB embedding check for ``model_info``.

    Before the check runs, the HF reference model's ``encode`` is pinned to
    ``task="text-matching"`` through ``hf_model_callback``.
    """

    def hf_model_callback(model):
        # Bind task="text-matching" so every encode() call made on the HF
        # model by the shared helper uses that task (presumably selecting
        # the jina-v3 text-matching adapter -- confirm against the model).
        model.encode = partial(model.encode, task="text-matching")

    mteb_test_embed_models(hf_runner,
                           vllm_runner,
                           model_info,
                           hf_model_callback=hf_model_callback)
46
47


48
49
50
51
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
                                  model_info: EmbedModelInfo,
                                  example_prompts) -> None:
    """Run the shared embedding-correctness check on ``example_prompts``.

    As in the MTEB test, the HF reference model's ``encode`` is pinned to
    ``task="text-matching"`` through ``hf_model_callback``.
    """

    def hf_model_callback(model):
        # Bind task="text-matching" so every encode() call made on the HF
        # model by the shared helper uses that task.
        model.encode = partial(model.encode, task="text-matching")

    correctness_test_embed_models(hf_runner,
                                  vllm_runner,
                                  model_info,
                                  example_prompts,
                                  hf_model_callback=hf_model_callback)
61
62


63
64
65
@pytest.mark.parametrize("model_info", RERANK_MODELS)
def test_rerank_models_mteb(hf_runner, vllm_runner,
                            model_info: RerankModelInfo) -> None:
    """Run the shared MTEB rerank check for ``model_info``.

    Skipped when ``envs.VLLM_USE_V1`` is set and the model architecture is
    XLMRobertaForSequenceClassification, which is not supported there yet.
    """
    if (model_info.architecture == "XLMRobertaForSequenceClassification"
            and envs.VLLM_USE_V1):
        pytest.skip("Not supported yet")

    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)


73
@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("dimensions", [16, 32])
def test_matryoshka(
    hf_runner,
    vllm_runner,
    model_info: EmbedModelInfo,
    dtype: str,
    dimensions: int,
    example_prompts,
    monkeypatch,
) -> None:
    """Check vLLM's truncated (matryoshka) embeddings against HF.

    The HF sentence-transformers reference embeddings are truncated to
    ``dimensions`` with ``matryoshka_fy``. If ``dimensions`` is not one of
    the model config's ``matryoshka_dimensions``, vLLM's ``embed`` must raise
    ``ValueError``; otherwise its outputs must match the reference via
    ``check_embeddings_close`` with ``tol=1e-2``.

    Skipped for models whose ``model_info.is_matryoshka`` is false.
    """
    # NOTE(review): ``monkeypatch`` is requested but not used in this test.
    if not model_info.is_matryoshka:
        pytest.skip("Model is not matryoshka")

    # ST (sentence-transformers) will strip the input texts, see
    # test_embedding.py; strip here too so both sides see identical inputs.
    example_prompts = [str(s).strip() for s in example_prompts]

    with hf_runner(
            model_info.name,
            dtype=dtype,
            is_sentence_transformer=True,
    ) as hf_model:
        hf_outputs = hf_model.encode(example_prompts, task="text-matching")
        hf_outputs = matryoshka_fy(hf_outputs, dimensions)

    with vllm_runner(model_info.name,
                     runner="pooling",
                     dtype=dtype,
                     max_model_len=None) as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config
        assert model_config.is_matryoshka

        matryoshka_dimensions = model_config.matryoshka_dimensions
        assert matryoshka_dimensions is not None

        if dimensions not in matryoshka_dimensions:
            # Unsupported truncation size: vLLM must reject the request.
            # PoolingParams is built inside the raises-block on purpose, so
            # a ValueError from either construction or embed() is accepted.
            with pytest.raises(ValueError):
                vllm_model.embed(
                    example_prompts,
                    pooling_params=PoolingParams(dimensions=dimensions))
        else:
            vllm_outputs = vllm_model.embed(
                example_prompts,
                pooling_params=PoolingParams(dimensions=dimensions))

            check_embeddings_close(
                embeddings_0_lst=hf_outputs,
                embeddings_1_lst=vllm_outputs,
                name_0="hf",
                name_1="vllm",
                tol=1e-2,
            )