test_rerank.py 5.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest
import requests
6
7
import torch
import torch.nn.functional as F
8

9
from tests.utils import RemoteOpenAIServer
10
from vllm.entrypoints.openai.protocol import PoolingResponse, RerankResponse
11
12

# Cross-encoder reranker served by the module-scoped test server.
MODEL_NAME = "BAAI/bge-reranker-base"
# Passed to the server as ``--dtype``.
DTYPE = "bfloat16"


@pytest.fixture(scope="module")
def server():
    """Start one OpenAI-compatible server shared by every test in this module.

    The server serves ``MODEL_NAME`` with ``--enforce-eager``, a
    ``--max-model-len`` of 100 (the limit that ``test_rerank_max_model_len``
    deliberately exceeds) and ``--dtype DTYPE``.

    Yields:
        The running ``RemoteOpenAIServer``. Leaving the ``with`` block after
        the module's last test presumably shuts it down (context-manager
        exit of ``RemoteOpenAIServer``).
    """
    args = ["--enforce-eager", "--max-model-len", "100", "--dtype", DTYPE]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_texts(server: RemoteOpenAIServer, model_name: str):
    """Rerank one relevant and one irrelevant document against a query.

    Expects both documents back, the first scoring >= 0.9 and the second
    <= 0.01. Treating ``results[0]`` as the relevant (Paris) document
    assumes the endpoint returns results sorted by descending relevance
    score.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={
            "model": model_name,
            "query": query,
            "documents": documents,
        },
    )
    # Fail fast with the HTTP error instead of a confusing validation error.
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_top_n(server: RemoteOpenAIServer, model_name: str):
    """Check that ``top_n`` truncates the rerank results.

    Three documents are sent with ``top_n=2``; exactly two results must come
    back, the first scoring >= 0.9 and the second <= 0.01 (this assumes the
    results are ordered by descending relevance score).
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
        "Cross-encoder models are neat",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={"model": model_name, "query": query, "documents": documents, "top_n": 2},
    )
    rerank_response.raise_for_status()
    rerank = RerankResponse.model_validate(rerank_response.json())

    assert rerank.id is not None
    assert rerank.results is not None
    assert len(rerank.results) == 2
    assert rerank.results[0].relevance_score >= 0.9
    assert rerank.results[1].relevance_score <= 0.01


@pytest.mark.parametrize("model_name", [MODEL_NAME])
def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
    """Reject a rerank request whose input exceeds the server's context limit.

    The query is the question repeated 100 times (about 3000 characters),
    far beyond the ``--max-model-len 100`` the ``server`` fixture sets, so
    the server must answer with HTTP 400.
    """
    query = "What is the capital of France?" * 100
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    rerank_response = requests.post(
        server.url_for("rerank"),
        json={"model": model_name, "query": query, "documents": documents},
    )
    assert rerank_response.status_code == 400
    # Assert only a small fragment of the error text so the test does not
    # break if the rest of the message wording changes.
    assert "Please reduce the length of the input." in rerank_response.text


def test_invocations(server: RemoteOpenAIServer):
    """Check that ``/invocations`` mirrors ``/rerank`` for a rerank payload.

    The same request body is posted to both endpoints. The responses must
    have identical top-level keys, identical keys per result, and
    relevance scores that agree within 5% relative tolerance.
    """
    query = "What is the capital of France?"
    documents = [
        "The capital of Brazil is Brasilia.",
        "The capital of France is Paris.",
    ]

    request_args = {
        "model": MODEL_NAME,
        "query": query,
        "documents": documents,
    }

    rerank_response = requests.post(server.url_for("rerank"), json=request_args)
    rerank_response.raise_for_status()

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    rerank_output = rerank_response.json()
    invocation_output = invocation_response.json()

    assert rerank_output.keys() == invocation_output.keys()
    for rerank_result, invocations_result in zip(
        rerank_output["results"], invocation_output["results"]
    ):
        assert rerank_result.keys() == invocations_result.keys()
        # TODO: reset this tolerance to 0.01 once we find
        # an alternative to flash_attn with bfloat16
        assert rerank_result["relevance_score"] == pytest.approx(
            invocations_result["relevance_score"], rel=0.05
        )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_activation(server: RemoteOpenAIServer, model_name: str):
    """Check how the ``activation`` request flag affects rerank scores.

    Expectations (all within ``atol=1e-2``):
      * omitting the flag (``None``) gives the same scores as ``True``;
      * ``True`` and ``False`` give different scores;
      * applying a sigmoid to the ``False`` scores reproduces the ``True``
        scores, i.e. ``False`` returns the raw, un-activated outputs.
    """

    async def get_outputs(activation):
        """Rerank a fixed query/document pair and return the scores as a tensor.

        ``activation`` is forwarded verbatim in the request body (``None``
        is serialized as JSON ``null``).
        """
        query = "What is the capital of France?"
        documents = [
            "The capital of Brazil is Brasilia.",
            "The capital of France is Paris.",
        ]

        response = requests.post(
            server.url_for("rerank"),
            json={
                "model": model_name,
                "query": query,
                "documents": documents,
                "activation": activation,
            },
        )
        # Surface HTTP errors directly rather than as a KeyError on "results".
        response.raise_for_status()
        outputs = response.json()

        return torch.tensor([x["relevance_score"] for x in outputs["results"]])

    default = await get_outputs(activation=None)
    w_activation = await get_outputs(activation=True)
    wo_activation = await get_outputs(activation=False)

    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."
    )
    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
        "wo_activation should not use activation."
    )
    assert torch.allclose(F.sigmoid(wo_activation), w_activation, atol=1e-2), (
        "w_activation should be close to activation(wo_activation)."
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling(server: RemoteOpenAIServer, model_name: str):
    """Call the ``/pooling`` endpoint on the reranker and check output shape.

    One input sentence is sent; the response must contain one entry whose
    ``data`` has 11 rows of 1 value each.
    """
    input_text = ["The chef prepared a delicious meal."]

    response = requests.post(
        server.url_for("pooling"),
        json={"model": model_name, "input": input_text, "encoding_format": "float"},
    )
    # Report HTTP failures directly instead of as a pydantic validation error.
    response.raise_for_status()

    poolings = PoolingResponse.model_validate(response.json())

    assert len(poolings.data) == 1
    # NOTE(review): 11 is presumably the tokenized length of input_text
    # (one row per token) and 1 the per-token output width of this
    # classifier — confirm against the model's tokenizer/head.
    assert len(poolings.data[0].data) == 11
    assert len(poolings.data[0].data[0]) == 1