test_offline.py 3.63 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import weakref

import pytest
import torch
import torch.nn.functional as F

9
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
10
from vllm.distributed import cleanup_dist_env_and_memory
11
from vllm.platforms import current_platform
12
from vllm.tasks import PoolingTask
13

14
15
# Model identifier passed as `model=` when the `llm` fixture builds the engine.
MODEL_NAME = "intfloat/multilingual-e5-small"

16
17
18
# Sample text prompt shared by the tests in this module.
prompt = "The chef prepared a delicious meal."
# Token ids the tests expect to see reported for `prompt`; also passed
# directly to `llm.embed` as a pre-tokenized prompt.
prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
# Expected embedding length; compared with `llm.model_config.embedding_size`
# in the fixture and with the length of every returned embedding.
embedding_size = 384
19
20
21
22


@pytest.fixture(scope="module")
def llm():
    """Yield a module-scoped embedding engine built from ``MODEL_NAME``.

    The engine is created once, its reported embedding size is checked
    against ``embedding_size``, and a weak proxy to it is yielded to the
    tests. After the yield, the local reference is deleted and
    ``cleanup_dist_env_and_memory()`` is called.
    """
    # ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
    # that supports encoder-only models on ROCm.
    attention_config = (
        {"backend": "FLEX_ATTENTION"} if current_platform.is_rocm() else None
    )

    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
        attention_config=attention_config,
    )
    assert embedding_size == engine.model_config.embedding_size

    # pytest caches the fixture value, so only a weakref.proxy is handed out;
    # that way the engine can still be garbage collected after `del` below.
    yield weakref.proxy(engine)

    del engine
    cleanup_dist_env_and_memory()


48
@pytest.mark.skip_global_cleanup
def test_str_prompts(llm: LLM):
    """Embed a single string prompt and check the one returned output."""
    results = llm.embed(prompt, use_tqdm=False)
    assert len(results) == 1

    first = results[0]
    assert isinstance(first, EmbeddingRequestOutput)
    assert first.prompt_token_ids == prompt_token_ids
    assert len(first.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_token_ids_prompts(llm: LLM):
    """Embed a single pre-tokenized prompt and check the returned output."""
    token_prompts = [prompt_token_ids]
    results = llm.embed(token_prompts, use_tqdm=False)
    assert len(results) == 1

    first = results[0]
    assert isinstance(first, EmbeddingRequestOutput)
    assert first.prompt_token_ids == prompt_token_ids
    assert len(first.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_list_prompts(llm: LLM):
    """Embed a mixed batch (one string prompt, one token-id prompt).

    Both entries describe the same input, so every output is expected to
    report ``prompt_token_ids`` and an embedding of ``embedding_size``.
    """
    outputs = llm.embed([prompt, prompt_token_ids], use_tqdm=False)
    assert len(outputs) == 2
    # Iterate the outputs directly rather than indexing via range(len(...)).
    for output in outputs:
        assert isinstance(output, EmbeddingRequestOutput)
        assert output.prompt_token_ids == prompt_token_ids
        assert len(output.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
    """Compare embeddings for ``use_activation`` set to None, True and False.

    The default (None) result must match the True result, the True and
    False results must differ, and L2-normalizing the False result must
    reproduce the True result (all within ``atol=1e-2``).
    """

    def embed_with(normalize):
        # Embed the shared prompt with the given `use_activation` setting
        # and stack the resulting embeddings into one tensor.
        params = PoolingParams(use_activation=normalize)
        results = llm.embed([prompt], pooling_params=params, use_tqdm=False)
        return torch.tensor([item.outputs.embedding for item in results])

    default = embed_with(None)
    w_normal = embed_with(True)
    wo_normal = embed_with(False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."

    same_without_normal = torch.allclose(w_normal, wo_normal, atol=1e-2)
    assert not same_without_normal, "wo_normal should not use normal."

    renormalized = F.normalize(wo_normal, p=2, dim=-1)
    assert torch.allclose(w_normal, renormalized, atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )
97
98


99
100
101
@pytest.mark.parametrize(
    "task", ["token_classify", "classify", "token_embed", "plugin"]
)
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
    """Check that ``llm.encode`` raises ``ValueError`` for each listed task.

    The expected message pattern depends on the task: "plugin" and
    "token_embed" have their own patterns, every other task falls back to
    the classification message.
    """
    # Regex patterns handed to pytest.raises(match=...), keyed by task.
    task_specific_msgs = {
        "plugin": "No IOProcessor plugin installed.",
        "token_embed": "Try switching the model's pooling_task via.+",
    }
    err_msg = task_specific_msgs.get(
        task, "Classification API is not supported by this model.+"
    )

    with pytest.raises(ValueError, match=err_msg):
        llm.encode(prompt, pooling_task=task, use_tqdm=False)