test_embedding.py 15.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import base64

import numpy as np
7
8
import openai
import pytest
9
import pytest_asyncio
10
import requests
11
12
import torch
import torch.nn.functional as F
13

14
from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
15
16
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
17
18
19
20
from vllm.entrypoints.openai.protocol import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    EmbeddingResponse,
)
21
from vllm.transformers_utils.tokenizer import get_tokenizer
22

23
# Embedding model launched by the `server` fixture and queried by every test.
MODEL_NAME = "intfloat/multilingual-e5-small"
# Minimal Jinja chat template: each message is rendered as "<role>: <content>"
# followed by a newline.
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
# Precision passed both to the server (--dtype) and to the HF reference runner.
DTYPE = "bfloat16"
26
27
28


@pytest.fixture(scope="module")
def server():
    """Start one `RemoteOpenAIServer` for `MODEL_NAME`, shared by the module.

    The server is launched with the pooling runner, `DTYPE` precision, eager
    execution, a 512-token context limit and `DUMMY_CHAT_TEMPLATE`.
    """
    launch_args = ["--runner", "pooling"]
    # use half precision for speed and memory savings in CI environment
    launch_args += ["--dtype", DTYPE]
    launch_args.append("--enforce-eager")
    launch_args += ["--max-model-len", "512"]
    launch_args += ["--chat-template", DUMMY_CHAT_TEMPLATE]

    with RemoteOpenAIServer(MODEL_NAME, launch_args) as remote_server:
        yield remote_server
45
46


47
@pytest_asyncio.fixture
async def client(server):
    """Yield an async client obtained from `server`.

    The client's async context manager is exited once the test finishes.
    """
    client_ctx = server.get_async_client()
    async with client_ctx as async_client:
        yield async_client
51
52


53
54
@pytest.fixture(scope="module")
def hf_model(hf_runner):
    """Yield a module-scoped reference runner for `MODEL_NAME`.

    It is built by `hf_runner` as a sentence-transformer in `DTYPE` precision
    and is passed to `run_embedding_correctness_test` by the tests below.
    """
    reference_runner = hf_runner(
        MODEL_NAME,
        dtype=DTYPE,
        is_sentence_transformer=True,
    )
    with reference_runner as model:
        yield model


59
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
    """Embed a single prompt, first as text and then as raw token IDs.

    Both responses are checked for embedding count, dimension and token
    usage; the text result is also compared against the HF reference model.
    """
    # --- one text prompt ---
    input_texts = ["The chef prepared a delicious meal."]

    text_response = await client.embeddings.create(
        model=model_name,
        input=input_texts,
        encoding_format="float",
    )
    text_payload = text_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(text_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 11
    assert usage.total_tokens == 11

    vllm_outputs = [item.embedding for item in embeddings.data]
    run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)

    # --- one prompt given as a flat list of token IDs ---
    input_tokens = [1, 1, 1, 1, 1]

    token_response = await client.embeddings.create(
        model=model_name,
        input=input_tokens,
        encoding_format="float",
    )
    token_payload = token_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(token_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 5
    assert usage.total_tokens == 5


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
    """Embed a batch of prompts, as strings and then as token-ID lists.

    Both responses are checked for embedding count, dimension and summed
    token usage; the string batch is also compared against the HF reference.
    """
    # test list[str]
    input_texts = [
        "The cat sat on the mat.",
        "A feline was resting on a rug.",
        "Stars twinkle brightly in the night sky.",
    ]
    text_response = await client.embeddings.create(
        model=model_name,
        input=input_texts,
        encoding_format="float",
    )
    text_payload = text_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(text_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 3
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 33
    assert usage.total_tokens == 33

    vllm_outputs = [item.embedding for item in embeddings.data]
    run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)

    # test list[list[int]]
    input_tokens = [
        [4, 5, 7, 9, 20],
        [15, 29, 499],
        [24, 24, 24, 24, 24],
        [25, 32, 64, 77],
    ]
    token_response = await client.embeddings.create(
        model=model_name,
        input=input_tokens,
        encoding_format="float",
    )
    token_payload = token_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(token_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 4
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 17
    assert usage.total_tokens == 17
155
156
157


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_conversation_embedding(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Check that chat-style input embeds the same as its rendered prompt.

    The conversation is posted as `messages` to the embeddings endpoint, and
    separately rendered with `DUMMY_CHAT_TEMPLATE` and embedded as plain
    `input`. The two responses must match apart from `id` and `created`.
    """
    turns = [
        ("user", "The cat sat on the mat."),
        ("assistant", "A feline was resting on a rug."),
        ("user", "Stars twinkle brightly in the night sky."),
    ]
    messages = [{"role": role, "content": content} for role, content in turns]

    # Chat path: send the messages directly over HTTP.
    chat_body = {
        "model": model_name,
        "messages": messages,
        "encoding_format": "float",
    }
    chat_response = requests.post(server.url_for("v1/embeddings"), json=chat_body)
    chat_response.raise_for_status()
    chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())

    # Completion path: render the prompt locally, then embed it as text.
    tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
    prompt = tokenizer.apply_chat_template(
        messages,
        chat_template=DUMMY_CHAT_TEMPLATE,
        add_generation_prompt=True,
        continue_final_message=False,
        tokenize=False,
    )
    completion_response = await client.embeddings.create(
        model=model_name,
        input=prompt,
        encoding_format="float",
        # To be consistent with chat
        extra_body={"add_special_tokens": False},
    )
    completion_payload = completion_response.model_dump(mode="json")
    completion_embeddings = EmbeddingResponse.model_validate(completion_payload)

    assert chat_embeddings.id is not None
    assert completion_embeddings.id is not None
    assert chat_embeddings.created <= completion_embeddings.created

    # `id` and `created` differ per request, so leave them out of the comparison.
    volatile_fields = {"id", "created"}
    assert chat_embeddings.model_dump(exclude=volatile_fields) == (
        completion_embeddings.model_dump(exclude=volatile_fields)
    )
213
214
215
216


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_batch_base64_embedding(
    hf_model, client: openai.AsyncOpenAI, model_name: str
):
    """Compare float, base64 and default encodings against the HF reference.

    The same two prompts are embedded three times; each result is passed to
    `run_embedding_correctness_test` together with `hf_model`.
    """
    input_texts = [
        "Hello my name is",
        "The best thing about vLLM is that it supports many different models",
    ]

    # Explicit float encoding.
    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [item.embedding for item in responses_float.data]
    run_embedding_correctness_test(hf_model, input_texts, float_data)

    # Explicit base64 encoding: decode each payload as float32 values.
    responses_base64 = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="base64"
    )
    base64_data = [
        np.frombuffer(base64.b64decode(item.embedding), dtype="float32").tolist()
        for item in responses_base64.data
    ]
    run_embedding_correctness_test(hf_model, input_texts, base64_data)

    # Default response is float32 decoded from base64 by OpenAI Client
    responses_default = await client.embeddings.create(
        input=input_texts, model=model_name
    )
    default_data = [item.embedding for item in responses_default.data]
    run_embedding_correctness_test(hf_model, input_texts, default_data)
248
249


250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype(
    hf_model, server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Check base64 embeddings for every entry of `EMBED_DTYPE_TO_TORCH_DTYPE`.

    A float-encoded response is the baseline. For each `embed_dtype` the
    base64 payload is decoded with the matching torch dtype, converted to
    float32, and compared to the baseline with a tolerance of 1e-2.

    `hf_model` is requested but not used in the body.
    """
    input_texts = [
        "The best thing about vLLM is that it supports many different models",
    ]

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [d.embedding for d in responses_float.data]

    for embed_dtype, torch_dtype in EMBED_DTYPE_TO_TORCH_DTYPE.items():
        responses_base64 = requests.post(
            server.url_for("/v1/embeddings"),
            json={
                "model": model_name,
                "input": input_texts,
                "encoding_format": "base64",
                "embed_dtype": embed_dtype,
            },
        )
        # Fail here with the HTTP status rather than with a KeyError on
        # "data" below when the server rejects the request.
        responses_base64.raise_for_status()

        base64_data = [
            torch.frombuffer(base64.b64decode(data["embedding"]), dtype=torch_dtype)
            .to(torch.float32)
            .tolist()
            for data in responses_base64.json()["data"]
        ]

        check_embeddings_close(
            embeddings_0_lst=float_data,
            embeddings_1_lst=base64_data,
            name_0="float_data",
            name_1="base64_data",
            tol=1e-2,
        )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_not_supported(
    hf_model, server: RemoteOpenAIServer, model_name: str
):
    """Expect HTTP 400 and a descriptive message for an unknown `embed_dtype`.

    `hf_model` is requested but not used in the body.
    """
    bad_embed_dtype = "bad_embed_dtype"
    payload = {
        "model": model_name,
        "input": [
            "The best thing about vLLM is that it supports many different models",
        ],
        "encoding_format": "base64",
        "embed_dtype": bad_embed_dtype,
    }

    responses_base64 = requests.post(server.url_for("/v1/embeddings"), json=payload)

    assert responses_base64.status_code == 400
    error_message = responses_base64.json()["error"]["message"]
    expected_prefix = f"embed_dtype={bad_embed_dtype!r} is not supported."
    assert error_message.startswith(expected_prefix)


319
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation(client: openai.AsyncOpenAI, model_name: str):
    """Check `truncate_prompt_tokens=10` for text input and token-ID input.

    In both cases the response must hold one 384-dimensional embedding and
    report exactly 10 prompt tokens.
    """
    truncation = {"truncate_prompt_tokens": 10}

    # test single embedding
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]
    text_response = await client.embeddings.create(
        model=model_name, input=input_texts, extra_body=truncation
    )
    text_payload = text_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(text_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 10
    assert usage.total_tokens == 10

    # A 21-token prompt given as IDs; the request asks to truncate it to 10.
    input_tokens = [
        1, 24428, 289, 18341, 26165, 285, 19323,
        283, 289, 26789, 3871, 28728, 9901, 340,
        2229, 385, 340, 315, 28741, 28804, 2,
    ]  # fmt: skip
    token_response = await client.embeddings.create(
        model=model_name, input=input_tokens, extra_body=truncation
    )
    token_payload = token_response.model_dump(mode="json")
    embeddings = EmbeddingResponse.model_validate(token_payload)

    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 384
    usage = embeddings.usage
    assert usage.completion_tokens == 0
    assert usage.prompt_tokens == 10
    assert usage.total_tokens == 10


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_embedding_truncation_invalid(
    client: openai.AsyncOpenAI, model_name: str
):
    """Expect a 400 error when `truncate_prompt_tokens` exceeds the model limit.

    The request asks for truncation to 8193 tokens while the server is
    started with `--max-model-len 512`, so it must be rejected with
    `openai.BadRequestError`.
    """
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    with pytest.raises(openai.BadRequestError) as exc_info:
        await client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193},
        )

    # The checks used to sit inside the `with` body, after the call that
    # raises, so they never ran. Inspect the captured exception instead.
    # NOTE(review): the exact server wording was never exercised by the old
    # unreachable assertions, so only the parameter name is checked here;
    # tighten once the message is confirmed.
    assert "truncate_prompt_tokens" in str(exc_info.value)
399
400
401


@pytest.mark.asyncio
async def test_invocations(server: RemoteOpenAIServer, client: openai.AsyncOpenAI):
    """Check that `invocations` answers an embedding request like the client API.

    The same request is sent through `client.embeddings.create` and posted
    to the `invocations` route; both responses must share their top-level
    keys, per-item keys, and have close embeddings.
    """
    input_texts = [
        "The chef prepared a delicious meal.",
    ]

    request_args = {
        "model": MODEL_NAME,
        "input": input_texts,
        "encoding_format": "float",
    }

    completion_response = await client.embeddings.create(**request_args)

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    completion_output = completion_response.model_dump()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so a missing item on either side
    # would otherwise go unnoticed.
    assert len(completion_output["data"]) == len(invocation_output["data"])
    for completion_data, invocation_data in zip(
        completion_output["data"], invocation_output["data"]
    ):
        assert completion_data.keys() == invocation_data.keys()
        check_embeddings_close(
            embeddings_0_lst=[completion_data["embedding"]],
            embeddings_1_lst=[invocation_data["embedding"]],
            name_0="completion",
            name_1="invocation",
        )
434
435
436
437


@pytest.mark.asyncio
async def test_invocations_conversation(server: RemoteOpenAIServer):
    """Check that `invocations` answers a chat-style embedding request.

    The same `messages` payload is posted to `v1/embeddings` and to
    `invocations`; both responses must share their top-level keys,
    per-item keys, and have close embeddings.
    """
    messages = [
        {
            "role": "user",
            "content": "The cat sat on the mat.",
        },
        {
            "role": "assistant",
            "content": "A feline was resting on a rug.",
        },
        {
            "role": "user",
            "content": "Stars twinkle brightly in the night sky.",
        },
    ]

    request_args = {
        "model": MODEL_NAME,
        "messages": messages,
        "encoding_format": "float",
    }

    chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
    chat_response.raise_for_status()

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    chat_output = chat_response.json()
    invocation_output = invocation_response.json()

    assert chat_output.keys() == invocation_output.keys()
    # zip() stops at the shorter sequence, so a missing item on either side
    # would otherwise go unnoticed.
    assert len(chat_output["data"]) == len(invocation_output["data"])
    for chat_data, invocation_data in zip(
        chat_output["data"], invocation_output["data"]
    ):
        assert chat_data.keys() == invocation_data.keys()
        check_embeddings_close(
            embeddings_0_lst=[chat_data["embedding"]],
            embeddings_1_lst=[invocation_data["embedding"]],
            name_0="chat",
            name_1="invocation",
        )
481
482
483
484
485
486
487
488
489
490
491
492


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_normalize(server: RemoteOpenAIServer, model_name: str):
    """Check the effect of the `normalize` request field on embeddings.

    The same prompt is embedded with `normalize` set to None, True and
    False. The default must match True, True must differ from False, and
    L2-normalizing the False output must reproduce the True output.
    """
    input_text = ["The chef prepared a delicious meal."]

    async def get_outputs(normalize):
        """Post one embedding request and return the embeddings as a tensor."""
        request_args = {
            # Use the parametrized name instead of the module global so the
            # test follows its own parameter.
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "normalize": normalize,
        }

        response = requests.post(server.url_for("v1/embeddings"), json=request_args)
        # Report an HTTP failure directly instead of a KeyError on "data".
        response.raise_for_status()
        outputs = response.json()

        return torch.tensor([x["embedding"] for x in outputs["data"]])

    default = await get_outputs(normalize=None)
    w_normal = await get_outputs(normalize=True)
    wo_normal = await get_outputs(normalize=False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal, atol=1e-2), (
        "wo_normal should not use normal."
    )
    assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )