# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import base64
import json

import numpy as np
import openai
import pytest
import pytest_asyncio
import requests
import torch
import torch.nn.functional as F

from tests.models.language.pooling.embed_utils import run_embedding_correctness_test
from tests.models.utils import check_embeddings_close
from tests.utils import RemoteOpenAIServer
from vllm.entrypoints.pooling.embed.protocol import EmbeddingResponse
from vllm.entrypoints.pooling.pooling.protocol import PoolingResponse
from vllm.platforms import current_platform
from vllm.tokenizers import get_tokenizer
from vllm.utils.serial_utils import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    ENDIANNESS,
    MetadataItem,
    binary2tensor,
    build_metadata_items,
    decode_pooling_output,
)
# Embedding model served by the test server and used for the HF reference.
MODEL_NAME = "intfloat/multilingual-e5-small"
# Minimal chat template: renders each message as "<role>: <content>" plus a
# newline, so chat-style embedding requests can be compared against an
# equivalent plain-text prompt.
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
# dtype used both by the vLLM server and by the HF reference model.
DTYPE = "bfloat16"

# Reference prompt and the token ids the server's /tokenize endpoint is
# expected to return for it; the token-id form is also sent as embedding
# input and must behave like the text form.
input_text = "The best thing about vLLM is that it supports many different models"
input_tokens = [
    0,
    581,
    2965,
    13580,
    1672,
    81,
    23708,
    594,
    83,
    450,
    442,
    8060,
    7,
    5941,
    12921,
    115774,
    2,
]
# Runs at import time: on ROCm, turn off PyTorch's flash and memory-efficient
# SDPA kernels and leave only the math kernel enabled for this test process.
# NOTE(review): presumably this matters for the in-process HF reference model;
# the vLLM server itself is launched separately via RemoteOpenAIServer — confirm.
if current_platform.is_rocm():
    # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
    # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
    # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)


@pytest.fixture(scope="module")
def server():
    """Start one vLLM OpenAI-compatible server (pooling runner) per module.

    The server runs MODEL_NAME with a 512-token context and
    DUMMY_CHAT_TEMPLATE so chat-style (``messages``) embedding requests can be
    rendered. Yields the running RemoteOpenAIServer.
    """
    args = [
        "--runner",
        "pooling",
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        DTYPE,
        "--enforce-eager",
        "--max-model-len",
        "512",
        "--chat-template",
        DUMMY_CHAT_TEMPLATE,
    ]

    # ROCm: Use Flex Attention to support encoder-only self-attention.
    if current_platform.is_rocm():
        args.extend(["--attention-backend", "FLEX_ATTENTION"])

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    """Yield an async OpenAI client connected to the module's test server."""
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.fixture(scope="module")
def hf_model(hf_runner):
    """Yield the HF sentence-transformers reference model for MODEL_NAME."""
    with hf_runner(MODEL_NAME, dtype=DTYPE, is_sentence_transformer=True) as hf_model:
        yield hf_model


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_basic(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Check /v1/models (raw HTTP and OpenAI client) and /tokenize."""
    # test /v1/models
    response = requests.get(server.url_for("/v1/models"))
    # Surface HTTP failures directly instead of as a KeyError on the body.
    response.raise_for_status()
    model = response.json()["data"][0]["id"]
    assert model == MODEL_NAME

    models = await client.models.list()
    served_model = models.data[0]
    assert served_model.id == MODEL_NAME

    # test /tokenize
    response = requests.post(
        server.url_for("/tokenize"),
        json={"model": model_name, "prompt": input_text},
    )
    response.raise_for_status()
    assert response.json()["tokens"] == input_tokens


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completion_request(
    client: openai.AsyncOpenAI, model_name: str, hf_model
):
    """Embed one prompt, given first as text and then as token ids.

    Both forms must return a single 384-dim embedding, report usage equal to
    the prompt length, and match the HF reference embedding of ``input_text``.
    """
    # test input: str, then the equivalent input: list[int]
    for prompt in (input_text, input_tokens):
        embedding_response = await client.embeddings.create(
            model=model_name,
            input=prompt,
            encoding_format="float",
        )
        embeddings = EmbeddingResponse.model_validate(
            embedding_response.model_dump(mode="json")
        )

        assert embeddings.id is not None
        assert len(embeddings.data) == 1
        assert len(embeddings.data[0].embedding) == 384
        assert embeddings.usage.completion_tokens == 0
        assert embeddings.usage.prompt_tokens == len(input_tokens)
        assert embeddings.usage.total_tokens == len(input_tokens)

        vllm_outputs = [d.embedding for d in embeddings.data]
        # The reference always embeds the text form of the same prompt.
        run_embedding_correctness_test(hf_model, [input_text], vllm_outputs)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_completion_request_batched(
    client: openai.AsyncOpenAI, model_name: str, hf_model
):
    """Embed a batch of N identical prompts, as texts and as token-id lists.

    Each form must return N 384-dim embeddings, report usage of N times the
    prompt length, and match the HF reference embeddings of the texts.
    """
    N = 10
    input_texts = [input_text] * N

    # test input: list[str], then the equivalent input: list[list[int]]
    for prompts in (input_texts, [input_tokens] * N):
        embedding_response = await client.embeddings.create(
            model=model_name,
            input=prompts,
            encoding_format="float",
        )
        embeddings = EmbeddingResponse.model_validate(
            embedding_response.model_dump(mode="json")
        )

        assert embeddings.id is not None
        assert len(embeddings.data) == N
        assert len(embeddings.data[0].embedding) == 384
        assert embeddings.usage.completion_tokens == 0
        assert embeddings.usage.prompt_tokens == len(input_tokens) * N
        assert embeddings.usage.total_tokens == len(input_tokens) * N

        vllm_outputs = [d.embedding for d in embeddings.data]
        # The reference always embeds the text form of the same prompts.
        run_embedding_correctness_test(hf_model, input_texts, vllm_outputs)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_truncate_prompt_tokens(client: openai.AsyncOpenAI, model_name: str):
    """Check ``truncate_prompt_tokens`` for text and token-id inputs.

    A valid truncation size caps the reported prompt usage at that size. A
    size larger than the server's ``--max-model-len`` (512) must be rejected
    with a BadRequestError.
    """
    truncation_size = 10
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]
    # Token-id prompt longer than ``truncation_size``. It has its own name so
    # the module-level ``input_tokens`` is not shadowed.
    long_input_tokens = [
        1,
        24428,
        289,
        18341,
        26165,
        285,
        19323,
        283,
        289,
        26789,
        3871,
        28728,
        9901,
        340,
        2229,
        385,
        340,
        315,
        28741,
        28804,
        2,
    ]

    # test single embedding, first as text and then as token ids
    for prompt in (input_texts, long_input_tokens):
        embedding_response = await client.embeddings.create(
            model=model_name,
            input=prompt,
            extra_body={"truncate_prompt_tokens": truncation_size},
        )
        embeddings = EmbeddingResponse.model_validate(
            embedding_response.model_dump(mode="json")
        )

        assert embeddings.id is not None
        assert len(embeddings.data) == 1
        assert len(embeddings.data[0].embedding) == 384
        assert embeddings.usage.completion_tokens == 0
        assert embeddings.usage.prompt_tokens == truncation_size
        assert embeddings.usage.total_tokens == truncation_size

    # invalid_truncate_prompt_tokens: larger than the server's max_model_len
    with pytest.raises(openai.BadRequestError) as exc_info:
        await client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193},
        )
    # Inspect the error outside the `with` block. The call above raises, so any
    # assertion placed after it inside the block would never run.
    assert "truncate_prompt_tokens" in str(exc_info.value)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_chat_request(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Check chat-style (``messages``) embedding requests.

    - A chat request must produce the same result as a completion request on
      the prompt rendered locally with DUMMY_CHAT_TEMPLATE.
    - The chat rendering options produce the expected prompt token counts.
    - Combining ``continue_final_message`` with ``add_generation_prompt`` is
      rejected.
    """
    messages = [
        {
            "role": "user",
            "content": "The cat sat on the mat.",
        },
        {
            "role": "assistant",
            "content": "A feline was resting on a rug.",
        },
        {
            "role": "user",
            "content": "Stars twinkle brightly in the night sky.",
        },
    ]

    def post_chat(**extra_fields) -> requests.Response:
        # POST a chat-style embedding request, adding any extra JSON fields.
        return requests.post(
            server.url_for("v1/embeddings"),
            json={"model": model_name, "messages": messages, **extra_fields},
        )

    # test chat request basic usage
    chat_response = post_chat(encoding_format="float")
    chat_response.raise_for_status()
    chat_embeddings = EmbeddingResponse.model_validate(chat_response.json())

    tokenizer = get_tokenizer(tokenizer_name=model_name)
    prompt = tokenizer.apply_chat_template(
        messages,
        chat_template=DUMMY_CHAT_TEMPLATE,
        add_generation_prompt=True,
        continue_final_message=False,
        tokenize=False,
    )
    completion_response = await client.embeddings.create(
        model=model_name,
        input=prompt,
        encoding_format="float",
        # To be consistent with chat
        extra_body={"add_special_tokens": False},
    )
    completion_embeddings = EmbeddingResponse.model_validate(
        completion_response.model_dump(mode="json")
    )

    assert chat_embeddings.id is not None
    assert completion_embeddings.id is not None
    assert chat_embeddings.created <= completion_embeddings.created
    assert chat_embeddings.model_dump(exclude={"id", "created"}) == (
        completion_embeddings.model_dump(exclude={"id", "created"})
    )

    # test add_generation_prompt, continue_final_message and add_special_tokens
    for extra_fields, expected_prompt_tokens in (
        ({"add_generation_prompt": True}, 34),
        ({"continue_final_message": True}, 33),
        ({"add_special_tokens": True}, 36),
    ):
        response = post_chat(**extra_fields)
        response.raise_for_status()
        output = EmbeddingResponse.model_validate(response.json())

        assert output.object == "list"
        assert len(output.data) == 1
        assert output.model == MODEL_NAME
        assert output.usage.prompt_tokens == expected_prompt_tokens

    # test continue_final_message with add_generation_prompt: must be rejected
    response = post_chat(continue_final_message=True, add_generation_prompt=True)
    assert response.status_code == 400
    assert (
        "Cannot set both `continue_final_message` and `add_generation_prompt` to True."
        in response.json()["error"]["message"]
    )


@pytest.mark.asyncio
async def test_invocations_completion_request(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI
):
    """/invocations with a completion-style body must mirror /v1/embeddings."""
    request_args = {
        "model": MODEL_NAME,
        "input": input_text,
        "encoding_format": "float",
    }

    completion_response = await client.embeddings.create(**request_args)

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    completion_output = completion_response.model_dump()
    invocation_output = invocation_response.json()

    assert completion_output.keys() == invocation_output.keys()
    # zip() stops at the shorter list, so compare lengths explicitly first.
    assert len(completion_output["data"]) == len(invocation_output["data"])
    for completion_data, invocation_data in zip(
        completion_output["data"], invocation_output["data"]
    ):
        assert completion_data.keys() == invocation_data.keys()
        check_embeddings_close(
            embeddings_0_lst=[completion_data["embedding"]],
            embeddings_1_lst=[invocation_data["embedding"]],
            name_0="completion",
            name_1="invocation",
        )


@pytest.mark.asyncio
async def test_invocations_chat_request(server: RemoteOpenAIServer):
    """/invocations with a chat-style body must mirror /v1/embeddings."""
    messages = [
        {
            "role": "user",
            "content": "The cat sat on the mat.",
        },
        {
            "role": "assistant",
            "content": "A feline was resting on a rug.",
        },
        {
            "role": "user",
            "content": "Stars twinkle brightly in the night sky.",
        },
    ]

    request_args = {
        "model": MODEL_NAME,
        "messages": messages,
        "encoding_format": "float",
    }

    chat_response = requests.post(server.url_for("v1/embeddings"), json=request_args)
    chat_response.raise_for_status()

    invocation_response = requests.post(
        server.url_for("invocations"), json=request_args
    )
    invocation_response.raise_for_status()

    chat_output = chat_response.json()
    invocation_output = invocation_response.json()

    assert chat_output.keys() == invocation_output.keys()
    # zip() stops at the shorter list, so compare lengths explicitly first.
    assert len(chat_output["data"]) == len(invocation_output["data"])
    for chat_data, invocation_data in zip(
        chat_output["data"], invocation_output["data"]
    ):
        assert chat_data.keys() == invocation_data.keys()
        check_embeddings_close(
            embeddings_0_lst=[chat_data["embedding"]],
            embeddings_1_lst=[invocation_data["embedding"]],
            name_0="chat",
            name_1="invocation",
        )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embedding(hf_model, client: openai.AsyncOpenAI, model_name: str):
    """Float, explicit-base64 and default encodings all match the HF reference."""
    input_texts = [
        "Hello my name is",
        "The best thing about vLLM is that it supports many different models",
    ]

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [d.embedding for d in responses_float.data]
    run_embedding_correctness_test(hf_model, input_texts, float_data)

    responses_base64 = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="base64"
    )
    # With an explicit base64 format the embedding field holds base64 text;
    # decode it here as raw float32 values.
    base64_data = [
        np.frombuffer(base64.b64decode(d.embedding), dtype="float32").tolist()
        for d in responses_base64.data
    ]
    run_embedding_correctness_test(hf_model, input_texts, base64_data)

    # Default response is float32 decoded from base64 by OpenAI Client
    responses_default = await client.embeddings.create(
        input=input_texts, model=model_name
    )
    default_data = [d.embedding for d in responses_default.data]
    run_embedding_correctness_test(hf_model, input_texts, default_data)


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_base64_embed_dtype_and_endianness(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Base64 output in every embed_dtype/endianness decodes back to (nearly)
    the float embeddings."""
    input_texts = [input_text] * 3

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [d.embedding for d in responses_float.data]

    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
        for endianness in ENDIANNESS:
            responses_base64 = requests.post(
                server.url_for("/v1/embeddings"),
                json={
                    "model": model_name,
                    "input": input_texts,
                    "encoding_format": "base64",
                    "embed_dtype": embed_dtype,
                    "endianness": endianness,
                },
            )
            # Surface server errors directly instead of as a KeyError on "data".
            responses_base64.raise_for_status()

            base64_data = []
            for data in responses_base64.json()["data"]:
                binary = base64.b64decode(data["embedding"])
                tensor = binary2tensor(binary, (-1,), embed_dtype, endianness)
                base64_data.append(tensor.to(torch.float32).tolist())

            # Loose tolerance, presumably for reduced-precision embed dtypes.
            check_embeddings_close(
                embeddings_0_lst=float_data,
                embeddings_1_lst=base64_data,
                name_0="float_data",
                name_1="base64_data",
                tol=1e-2,
            )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_embed_dtype_and_endianness(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """Raw-bytes output (layout in the "metadata" header) decodes back to
    (nearly) the float embeddings for every embed_dtype/endianness."""
    input_texts = [input_text] * 3

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [d.embedding for d in responses_float.data]

    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
        for endianness in ENDIANNESS:
            responses_bytes = requests.post(
                server.url_for("/v1/embeddings"),
                json={
                    "model": model_name,
                    "input": input_texts,
                    "encoding_format": "bytes",
                    "embed_dtype": embed_dtype,
                    "endianness": endianness,
                },
            )
            # Surface server errors directly instead of a missing-header KeyError.
            responses_bytes.raise_for_status()

            metadata = json.loads(responses_bytes.headers["metadata"])
            body = responses_bytes.content
            items = [MetadataItem(**x) for x in metadata["data"]]

            bytes_data = decode_pooling_output(items=items, body=body)
            bytes_data = [x.to(torch.float32).tolist() for x in bytes_data]

            check_embeddings_close(
                embeddings_0_lst=float_data,
                embeddings_1_lst=bytes_data,
                name_0="float_data",
                name_1="bytes_data",
                tol=1e-2,
            )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_bytes_only_embed_dtype_and_endianness(
    server: RemoteOpenAIServer, client: openai.AsyncOpenAI, model_name: str
):
    """"bytes_only" output carries no metadata header.

    The client rebuilds the layout from the known embedding size, and the
    decoded values must match the float embeddings.
    """
    input_texts = [
        "The best thing about vLLM is that it supports many different models",
    ] * 2

    responses_float = await client.embeddings.create(
        input=input_texts, model=model_name, encoding_format="float"
    )
    float_data = [d.embedding for d in responses_float.data]
    embedding_size = len(float_data[0])

    for embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE:
        for endianness in ENDIANNESS:
            responses_bytes = requests.post(
                server.url_for("/v1/embeddings"),
                json={
                    "model": model_name,
                    "input": input_texts,
                    "encoding_format": "bytes_only",
                    "embed_dtype": embed_dtype,
                    "endianness": endianness,
                },
            )
            # An error body would otherwise be decoded as if it were embeddings.
            responses_bytes.raise_for_status()

            assert "metadata" not in responses_bytes.headers
            body = responses_bytes.content
            items = build_metadata_items(
                embed_dtype=embed_dtype,
                endianness=endianness,
                shape=(embedding_size,),
                n_request=len(input_texts),
            )

            bytes_data = decode_pooling_output(items=items, body=body)
            bytes_data = [x.to(torch.float32).tolist() for x in bytes_data]

            check_embeddings_close(
                embeddings_0_lst=float_data,
                embeddings_1_lst=bytes_data,
                name_0="float_data",
                name_1="bytes_data",
                tol=1e-2,
            )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("param_name", ["encoding_format", "embed_dtype", "endianness"])
async def test_params_not_supported(
    server: RemoteOpenAIServer, model_name: str, param_name: str
):
    """An invalid value for a literal-typed parameter is rejected with 400."""
    response = requests.post(
        server.url_for("/v1/embeddings"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "base64",
            # Placed last so it overrides "encoding_format" when that is the
            # parameter under test.
            param_name: f"bad_{param_name}",
        },
    )

    assert response.status_code == 400
    error_message = response.json()["error"]["message"]
    assert "literal_error" in error_message
    assert f"bad_{param_name}" in error_message


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_normalize(server: RemoteOpenAIServer, model_name: str):
    """The ``normalize`` option L2-normalizes embeddings, and is on by default."""

    def get_outputs(normalize):
        # normalize=None sends JSON null, leaving the server default in effect.
        request_args = {
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "normalize": normalize,
        }

        response = requests.post(server.url_for("v1/embeddings"), json=request_args)
        response.raise_for_status()
        outputs = response.json()

        return torch.tensor([x["embedding"] for x in outputs["data"]])

    default = get_outputs(normalize=None)
    w_normal = get_outputs(normalize=True)
    wo_normal = get_outputs(normalize=False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal, atol=1e-2), (
        "wo_normal should not use normal."
    )
    assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_embed(server: RemoteOpenAIServer, model_name: str):
    """/pooling with task "embed" returns one 384-dim vector per input."""
    task = "embed"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": task,
        },
    )
    # Surface server errors directly rather than as a validation failure.
    response.raise_for_status()

    poolings = PoolingResponse.model_validate(response.json())

    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == 384


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_pooling_token_embed(server: RemoteOpenAIServer, model_name: str):
    """/pooling with task "token_embed" returns one 384-dim vector per token."""
    task = "token_embed"
    response = requests.post(
        server.url_for("pooling"),
        json={
            "model": model_name,
            "input": input_text,
            "encoding_format": "float",
            "task": task,
        },
    )
    # Surface server errors directly rather than as a validation failure.
    response.raise_for_status()

    poolings = PoolingResponse.model_validate(response.json())

    assert len(poolings.data) == 1
    assert len(poolings.data[0].data) == len(input_tokens)
    assert len(poolings.data[0].data[0]) == 384


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("task", ["classify", "token_classify", "plugin"])
async def test_pooling_not_supported(
    server: RemoteOpenAIServer, model_name: str, task: str
):
    """/pooling rejects tasks this embedding model does not support."""
    payload = {
        "model": model_name,
        "input": "test",
        "encoding_format": "float",
        "task": task,
    }
    error = requests.post(server.url_for("pooling"), json=payload).json()["error"]

    assert error["type"] == "BadRequestError"
    assert error["message"].startswith(f"Task {task} is not supported")