test_embedding.py 13 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import base64

import numpy as np
7
8
import openai
import pytest
9
import pytest_asyncio
10
11
import requests

12
from vllm.entrypoints.openai.protocol import EmbeddingResponse
13
from vllm.transformers_utils.tokenizer import get_tokenizer
14

15
16
from ...models.language.pooling.embed_utils import (
    run_embedding_correctness_test)
17
from ...models.utils import check_embeddings_close
18
from ...utils import RemoteOpenAIServer
19

20
# Model name handed to the server, the tokenizer and the reference runner.
MODEL_NAME = "intfloat/multilingual-e5-small"

# Tiny Jinja chat template: emits "<role>: <content>" per message. The doubled
# backslash keeps a literal `\n` escape inside the template text itself.
DUMMY_CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{message['role'] + ': ' + message['content'] + '\\n'}}"
    "{% endfor %}"
)

# Half precision: chosen for speed and memory savings in CI.
DTYPE = "bfloat16"
23
24


25
26
27
28
29
30
31
32
@pytest.fixture(autouse=True)
def v1(run_with_both_engines):
    """Autouse wrapper that runs each test in this module with both engines.

    Requesting ``run_with_both_engines`` is the only purpose; the fixture
    body itself does nothing. It can be promoted to ``conftest.py`` to apply
    to every test in a package.
    """


33
@pytest.fixture(scope="module")
def server():
    """Launch one ``RemoteOpenAIServer`` for the module and yield it.

    The server is started for ``MODEL_NAME`` with the embedding task, eager
    execution, a 512-token context limit and the dummy chat template.
    """
    cli_args = ["--task", "embed"]
    # use half precision for speed and memory savings in CI environment
    cli_args += ["--dtype", DTYPE]
    cli_args += ["--enforce-eager"]
    cli_args += ["--max-model-len", "512"]
    cli_args += ["--chat-template", DUMMY_CHAT_TEMPLATE]

    with RemoteOpenAIServer(MODEL_NAME, cli_args) as remote_server:
        yield remote_server
50
51


52
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from the running ``server`` fixture."""
    client_cm = server.get_async_client()
    async with client_cm as async_client:
        yield async_client
56
57


58
59
60
61
62
63
64
@pytest.fixture(scope="module")
def hf_model(hf_runner):
    """Yield the reference model built by ``hf_runner`` for ``MODEL_NAME``.

    Loaded once per module in sentence-transformer mode with the same dtype
    the server uses, so its outputs can serve as the comparison baseline.
    """
    runner_cm = hf_runner(
        MODEL_NAME,
        dtype=DTYPE,
        is_sentence_transformer=True,
    )
    with runner_cm as reference_model:
        yield reference_model


65
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
                                model_name: str):
    """Embed one prompt, given first as text and then as raw token IDs.

    The text result is additionally compared with the reference ``hf_model``
    through ``run_embedding_correctness_test``.
    """
    # --- one text prompt ---
    prompts = ["The chef prepared a delicious meal."]
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=prompts,
                                                  encoding_format="float")
    # Re-validate the client's payload against vLLM's EmbeddingResponse model.
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 1
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 11
    assert usage.total_tokens == 11

    run_embedding_correctness_test(hf_model, prompts,
                                   [item.embedding for item in parsed.data])

    # --- the same endpoint fed a flat list of token IDs ---
    token_ids = [1] * 5
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=token_ids,
                                                  encoding_format="float")
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 1
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 5
    assert usage.total_tokens == 5


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
                               model_name: str):
    """Embed a batch, given as list[str] and then as list[list[int]].

    The text batch is additionally compared with the reference ``hf_model``
    through ``run_embedding_correctness_test``.
    """
    # --- list[str] ---
    prompts = [
        "The cat sat on the mat.",
        "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky.",
    ]
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=prompts,
                                                  encoding_format="float")
    # Re-validate the client's payload against vLLM's EmbeddingResponse model.
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 3
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 33
    assert usage.total_tokens == 33

    run_embedding_correctness_test(hf_model, prompts,
                                   [item.embedding for item in parsed.data])

    # --- list[list[int]] ---
    token_batches = [
        [4, 5, 7, 9, 20],
        [15, 29, 499],
        [24, 24, 24, 24, 24],
        [25, 32, 64, 77],
    ]
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=token_batches,
                                                  encoding_format="float")
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 4
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    # 5 + 3 + 5 + 4 token IDs were submitted above.
    assert usage.prompt_tokens == 17
    assert usage.total_tokens == 17
154
155
156


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_conversation_embedding(server: RemoteOpenAIServer,
                                      client: openai.AsyncOpenAI,
                                      model_name: str):
    """A ``messages`` request must embed like its rendered prompt string.

    The conversation is sent once as chat messages and once as the text
    produced by applying ``DUMMY_CHAT_TEMPLATE`` locally; apart from ``id``
    and ``created`` the two responses must be equal.
    """
    turns = [
        ("user", "The cat sat on the mat."),
        ("assistant", "A feline was resting on a rug."),
        ("user", "Stars twinkle brightly in the night sky."),
    ]
    messages = [{"role": role, "content": text} for role, text in turns]

    # Chat form: sent with ``requests``; the payload carries "messages"
    # instead of "input".
    chat_payload = {
        "model": model_name,
        "messages": messages,
        "encoding_format": "float",
    }
    chat_response = requests.post(server.url_for("v1/embeddings"),
                                  json=chat_payload)
    chat_response.raise_for_status()
    chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())

    # Completion form: render the template locally, then embed the text.
    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
    rendered_prompt = tokenizer.apply_chat_template(
        messages,
        chat_template=DUMMY_CHAT_TEMPLATE,
        add_generation_prompt=True,
        continue_final_message=False,
        tokenize=False,
    )
    completion_response = await client.embeddings.create(
        model=model_name,
        input=rendered_prompt,
        encoding_format="float",
        # To be consistent with chat
        extra_body={"add_special_tokens": False},
    )
    completion_embeddings = EmbeddingResponse.model_validate(
        completion_response.model_dump(mode="json"))

    assert chat_embeddings.id is not None
    assert completion_embeddings.id is not None
    # The chat request was issued first, so it cannot carry a later timestamp.
    assert chat_embeddings.created <= completion_embeddings.created

    per_request_fields = {"id", "created"}
    chat_payload_out = chat_embeddings.model_dump(exclude=per_request_fields)
    completion_payload_out = completion_embeddings.model_dump(
        exclude=per_request_fields)
    assert chat_payload_out == completion_payload_out
207
208
209
210


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
                                      model_name: str):
    """Check float, base64 and default encodings against the reference.

    The same two prompts are embedded three times, once per encoding, and
    each result is passed to ``run_embedding_correctness_test``.
    """
    prompts = [
        "Hello my name is",
        "The best thing about vLLM is that it supports many different models"
    ]

    # Explicit float encoding.
    float_response = await client.embeddings.create(input=prompts,
                                                    model=model_name,
                                                    encoding_format="float")
    run_embedding_correctness_test(
        hf_model, prompts, [item.embedding for item in float_response.data])

    # Explicit base64 encoding: decode each payload as float32 by hand.
    base64_response = await client.embeddings.create(input=prompts,
                                                     model=model_name,
                                                     encoding_format="base64")
    decoded_vectors = [
        np.frombuffer(base64.b64decode(item.embedding),
                      dtype="float32").tolist()
        for item in base64_response.data
    ]
    run_embedding_correctness_test(hf_model, prompts, decoded_vectors)

    # Default response is float32 decoded from base64 by OpenAI Client
    default_response = await client.embeddings.create(input=prompts,
                                                      model=model_name)
    run_embedding_correctness_test(
        hf_model, prompts, [item.embedding for item in default_response.data])
240
241
242


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI,
                                           model_name: str):
    """``truncate_prompt_tokens`` must cap the reported prompt tokens.

    Exercised for a text prompt and for a 21-element token-ID prompt; both
    requests ask for truncation to 10 tokens and expect usage of exactly 10.
    """
    truncation = {"truncate_prompt_tokens": 10}

    # --- text prompt ---
    prompts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=prompts,
                                                  extra_body=truncation)
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 1
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 10
    assert usage.total_tokens == 10

    # --- token-ID prompt ---
    token_ids = [
        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
    ]
    raw_response = await client.embeddings.create(model=model_name,
                                                  input=token_ids,
                                                  extra_body=truncation)
    parsed = EmbeddingResponse.model_validate(
        raw_response.model_dump(mode="json"))

    assert parsed.id is not None
    assert len(parsed.data) == 1
    assert len(parsed.data[0].embedding) == 384
    usage = parsed.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 10
    assert usage.total_tokens == 10


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
                                                   model_name: str):
    """An oversized ``truncate_prompt_tokens`` must be rejected.

    The request asks for truncation to 8193 tokens and is expected to fail
    with ``openai.BadRequestError``.
    """
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    with pytest.raises(openai.BadRequestError) as exc_info:
        await client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193})

    # These checks must live outside the ``with`` block: statements placed
    # after the raising call inside it are never executed.
    assert exc_info.value.status_code == 400
    # Deliberately loose match so the check does not depend on the exact
    # server wording -- NOTE(review): tighten once the message is confirmed.
    assert "truncate_prompt_tokens" in str(exc_info.value)
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324


@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer,
                           client: openai.AsyncOpenAI):
    """``invocations`` must answer an embedding request like the client API.

    The same request is sent through ``client.embeddings.create`` and posted
    to the ``invocations`` route; the two responses must share their keys
    and have close embeddings.
    """
    input_texts = [
        "The chef prepared a delicious meal.",
    ]

    request_args = {
        "model": MODEL_NAME,
        "input": input_texts,
        "encoding_format": "float",
    }

    completion_response = await client.embeddings.create(**request_args)

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    completion_output = completion_response.model_dump()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()

    completion_items = completion_output["data"]
    invocation_items = invocation_output["data"]
    # zip() stops at the shorter list, so a missing or empty "data" list
    # would otherwise let the loop below pass without comparing anything.
    assert completion_items, "no embeddings returned"
    assert len(completion_items) == len(invocation_items)

    for completion_data, invocation_data in zip(completion_items,
                                                invocation_items):
        assert completion_data.keys() == invocation_data.keys()
        check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]],
                               embeddings_1_lst=[invocation_data["embedding"]],
                               name_0="completion",
                               name_1="invocation")
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364


@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
    """``invocations`` must answer a chat-style request like ``v1/embeddings``.

    The same ``messages`` payload is posted to both routes; the two responses
    must share their keys and have close embeddings.
    """
    messages = [{
        "role": "user",
        "content": "The cat sat on the mat.",
    }, {
        "role": "assistant",
        "content": "A feline was resting on a rug.",
    }, {
        "role": "user",
        "content": "Stars twinkle brightly in the night sky.",
    }]

    request_args = {
        "model": MODEL_NAME,
        "messages": messages,
        "encoding_format": "float",
    }

    chat_response = requests.post(server.url_for("v1/embeddings"),
                                  json=request_args)
    chat_response.raise_for_status()

    invocation_response = requests.post(server.url_for("invocations"),
                                        json=request_args)
    invocation_response.raise_for_status()

    chat_output = chat_response.json()
    invocation_output = invocation_response.json()

    assert chat_output.keys() == invocation_output.keys()

    chat_items = chat_output["data"]
    invocation_items = invocation_output["data"]
    # zip() stops at the shorter list, so a missing or empty "data" list
    # would otherwise let the loop below pass without comparing anything.
    assert chat_items, "no embeddings returned"
    assert len(chat_items) == len(invocation_items)

    for chat_data, invocation_data in zip(chat_items, invocation_items):
        assert chat_data.keys() == invocation_data.keys()
        check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]],
                               embeddings_1_lst=[invocation_data["embedding"]],
                               name_0="chat",
                               name_1="invocation")