# test_rerank.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest
import requests
6
7
import torch
import torch.nn.functional as F
8

9
from tests.utils import RemoteOpenAIServer
10
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
11
12
13
14
15
16
from vllm.platforms import current_platform

# Skip every test in this module on ROCm: per the skip reason below, encoder
# self-attention is not implemented there.
if current_platform.is_rocm():
    pytest.skip(
        "Encoder self-attention is not implemented on ROCm.", allow_module_level=True
    )

# Reranker model served by the module-scoped ``server`` fixture and used by
# every test in this module.
MODEL_NAME = "BAAI/bge-reranker-base"
# Value passed to the server's ``--dtype`` flag.
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
def server():
    """Start one OpenAI-compatible server for all tests in this module.

    The small ``--max-model-len`` of 100 is what makes the repeated query in
    ``test_rerank_max_model_len`` exceed the model length and get rejected.

    Yields:
        The running ``RemoteOpenAIServer``; it is shut down when the
        ``with`` block exits after the last test in the module.
    """
    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    """Rerank two documents and check the relevant one scores near 1.

    The Paris document is second in the request, but the assertions expect
    the top score at ``results[0]``. This relies on results being returned
    in descending score order rather than input order.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={
            "model": model_name,
            "query": query,
            "documents": documents,
        },
    )
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    """Request ``top_n=2`` over three documents and check only two come back.

    The highest-scoring result (the Paris document) is expected first, with
    a near-zero second result.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Cross-encoder models are neat",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={"model": model_name, "query": query, "documents": documents, "top_n": 2},
    )
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    # Three documents were sent; top_n must cap the results at two.
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
    """An input longer than the server's ``--max-model-len`` yields HTTP 400."""
    # Repeating the question 100 times overflows the 100-token model length.
    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={"model": model_name, "query": query, "documents": documents},
    )
    assert rerank_response.status_code == 400
    # Assert just a small fragment of the response
    assert "Please reduce the length of the input." in rerank_response.text


def test_invocations(server: RemoteOpenAIServer):
    """The ``invocations`` endpoint must return the same results as ``rerank``.

    The same request body is sent to both routes. The response keys and the
    per-result keys must match, and the scores must agree within a relative
    tolerance.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    request_args = {
        "model": MODEL_NAME,
        "query": query,
        "documents": documents,
    }

    rerank_response = requests.post(server.url_for("rerank"), json=request_args)
    rerank_response.raise_for_status()

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    rerank_output = rerank_response.json()
    invocation_output = invocation_response.json()

    assert rerank_output.keys() == invocation_output.keys()
    # zip() stops at the shorter list, so a missing result would otherwise
    # go unnoticed; compare lengths explicitly first.
    assert len(rerank_output["results"]) == len(invocation_output["results"])
    for rerank_result, invocations_result in zip(
        rerank_output["results"], invocation_output["results"]
    ):
        assert rerank_result.keys() == invocations_result.keys()
        # TODO: reset this tolerance to 0.01 once we find
        # an alternative to flash_attn with bfloat16
        assert rerank_result["relevance_score"] == pytest.approx(
            invocations_result["relevance_score"], rel=0.05
        )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_use_activation(server: RemoteOpenAIServer, model_name: str):
    """Check that ``use_activation`` toggles a sigmoid on the rerank scores.

    The test asserts three things:

    * The default (``use_activation=None``) must match ``True``.
    * ``False`` must differ from ``True``.
    * Applying sigmoid to the ``False`` scores must recover the ``True``
      scores.
    """

    async def get_outputs(use_activation):
        # Return the relevance scores, in response order, as a 1-D tensor.
        query = "What is the capital of France?"
        documents = [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ]

        response = requests.post(
            server.url_for("rerank"),
            json={
                "model": model_name,
                "query": query,
                "documents": documents,
                "use_activation": use_activation,
            },
        )
        # Report HTTP failures directly instead of as a KeyError on "results".
        response.raise_for_status()
        outputs = response.json()

        return torch.tensor([x["relevance_score"] for x in outputs["results"]])

    default = await get_outputs(use_activation=None)
    w_activation = await get_outputs(use_activation=True)
    wo_activation = await get_outputs(use_activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."
    )
    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
        "wo_activation should not use activation."
    )
    assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
        "w_activation should be close to activation(wo_activation)."
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_classify(server: RemoteOpenAIServer, model_name: str):
    """The ``classify`` pooling task returns a single value for one input text."""
    input_text = "This product was excellent and exceeded my expectations"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": "classify",
        },
    )
    # Surface HTTP errors directly instead of as a model-validation failure.
    response.raise_for_status()
    poolings = PoolingResponse.model_validate(response.json())
    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_classify(server: RemoteOpenAIServer, model_name: str):
    """Pooling without an explicit ``task`` returns one value per token."""
    input_text = ["The chef prepared a delicious meal."]

    response = requests.post(
        server.url_for("pooling"),
        json={"model": model_name, "input": input_text, "encoding_format": "float"},
    )
    # Surface HTTP errors directly instead of as a model-validation failure.
    response.raise_for_status()

    poolings = PoolingResponse.model_validate(response.json())

    assert len(poolings.data) == 1
    # NOTE(review): 11 is presumably the tokenized length of the sentence,
    # including special tokens; verify against the model's tokenizer.
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 1


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["embed", "token_embed", "plugin"])
async def test_pooling_not_supported(
    server: RemoteOpenAIServer, model_name: str, task: str
):
    """Unsupported pooling tasks must be rejected with a ``BadRequestError``."""
    payload = {
        "model": model_name,
        "input": "test",
        "encoding_format": "float",
        "task": task,
    }
    error = requests.post(server.url_for("pooling"), json=payload).json()["error"]

    assert error["type"] == "BadRequestError"
    expected_prefix = f"Task {task} is not supported"
    assert error["message"].startswith(expected_prefix)