test_offline.py 3.83 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import logging
4
5
6
7
8
9
import weakref

import pytest
import torch
import torch.nn.functional as F

10
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
11
from vllm.distributed import cleanup_dist_env_and_memory
12
from vllm.platforms import current_platform
13
from vllm.tasks import PoolingTask
14

15
16
# Model passed to ``LLM(model=...)`` by the module-scoped ``llm`` fixture.
MODEL_NAME = "intfloat/multilingual-e5-small"

17
18
19
# Text prompt shared by the tests in this module.
prompt = "The chef prepared a delicious meal."
# Token ids the tests expect to be reported for ``prompt``; also submitted
# directly as a pre-tokenized prompt.
prompt_token_ids = [0, 581, 21861, 133888, 10, 8, 150, 60744, 109911, 5, 2]
# Length the tests expect for each returned embedding vector.
embedding_size = 384
20
21
22
23


@pytest.fixture(scope="module")
def llm():
    """Module-scoped ``LLM`` engine shared by the tests in this file.

    Yields a ``weakref.proxy`` of the engine. After the module's tests
    finish, the engine reference is dropped and
    ``cleanup_dist_env_and_memory()`` is called.
    """
    # ROCm: FLEX_ATTENTION is the only attention backend that supports
    # encoder-only models there. Elsewhere ``None`` is passed through.
    attention_config = (
        {"backend": "FLEX_ATTENTION"} if current_platform.is_rocm() else None
    )

    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
        attention_config=attention_config,
    )

    # pytest caches the fixture value, so hand out a weak proxy to keep the
    # engine eligible for garbage collection once the local name is deleted.
    yield weakref.proxy(engine)

    del engine

    cleanup_dist_env_and_memory()


49
@pytest.mark.skip_global_cleanup
def test_str_prompts(llm: LLM):
    """Embedding one string prompt returns a single well-formed result."""
    results = llm.embed(prompt, use_tqdm=False)
    assert len(results) == 1

    result = results[0]
    assert isinstance(result, EmbeddingRequestOutput)
    # The reported token ids must match the expected tokenization of `prompt`.
    assert result.prompt_token_ids == prompt_token_ids
    assert len(result.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_token_ids_prompts(llm: LLM):
    """Embedding one pre-tokenized prompt returns a single well-formed result."""
    results = llm.embed([prompt_token_ids], use_tqdm=False)
    assert len(results) == 1

    result = results[0]
    assert isinstance(result, EmbeddingRequestOutput)
    # Token ids supplied as input must be reported back unchanged.
    assert result.prompt_token_ids == prompt_token_ids
    assert len(result.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_list_prompts(llm: LLM):
    """A mixed batch (string + token ids) yields one valid result per prompt."""
    outputs = llm.embed([prompt, prompt_token_ids], use_tqdm=False)
    assert len(outputs) == 2
    # Both entries describe the same prompt, so each result must report the
    # same token ids and the same embedding length.
    for output in outputs:
        assert isinstance(output, EmbeddingRequestOutput)
        assert output.prompt_token_ids == prompt_token_ids
        assert len(output.outputs.embedding) == embedding_size


@pytest.mark.skip_global_cleanup
def test_token_embed(llm: LLM, caplog_vllm):
    """``encode`` with the ``token_embed`` task warns and returns per-token data.

    Checks that a message containing "deprecated" is captured from the
    ``vllm`` logger at WARNING level or above, and that the returned data has
    one row per prompt token and ``embedding_size`` columns.

    ``caplog_vllm`` is a fixture defined outside this file.
    """
    with caplog_vllm.at_level(level=logging.WARNING, logger="vllm"):
        outputs = llm.encode(prompt, pooling_task="token_embed", use_tqdm=False)
        assert "deprecated" in caplog_vllm.text

    multi_vector = outputs[0].outputs.data
    # Derive the expected shape from the module constants instead of repeating
    # the literals (11 tokens, 384 dims) so the values cannot drift apart.
    assert multi_vector.shape == (len(prompt_token_ids), embedding_size)


87
@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
    """Check how ``PoolingParams(use_activation=...)`` affects embeddings.

    Compares the embeddings produced with ``use_activation`` set to ``None``,
    ``True`` and ``False``, and checks the ``True`` result against an explicit
    L2 normalization of the ``False`` result.
    """

    def embed_with(use_activation):
        # Embed the shared prompt with the requested setting and stack the
        # returned embeddings into a single tensor.
        results = llm.embed(
            [prompt],
            pooling_params=PoolingParams(use_activation=use_activation),
            use_tqdm=False,
        )
        return torch.tensor([item.outputs.embedding for item in results])

    default = embed_with(None)
    w_normal = embed_with(True)
    wo_normal = embed_with(False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal, atol=1e-2), (
        "wo_normal should not use normal."
    )
    assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )
108
109


110
@pytest.mark.parametrize("task", ["token_classify", "classify", "plugin"])
def test_unsupported_tasks(llm: LLM, task: PoolingTask):
    """``encode`` raises ``ValueError`` for each parametrized pooling task."""
    # `match` is applied as a regular expression to the error message.
    expected_error = (
        "No IOProcessor plugin installed."
        if task == "plugin"
        else "Classification API is not supported by this model.+"
    )

    with pytest.raises(ValueError, match=expected_error):
        llm.encode(prompt, pooling_task=task, use_tqdm=False)