# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch
import torch.nn.functional as F

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.platforms import current_platform

# Embedding model exercised by every test in this module.
MODEL_NAME = "intfloat/multilingual-e5-small"

# Single input prompt shared by all tests; test_token_embed asserts that its
# token_embed output has shape (11, 384).
prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(scope="module")
def llm():
    """Build one offline ``LLM`` engine shared by all tests in this module.

    Yields a ``weakref.proxy`` to the engine rather than the engine itself, so
    pytest's fixture cache does not hold a strong reference. On teardown the
    fixture's strong reference is deleted and ``cleanup_dist_env_and_memory()``
    is called.
    """
    # ROCm: Use FLEX_ATTENTION backend as it's the only attention backend
    # that supports encoder-only models on ROCm.
    attention_config = None
    if current_platform.is_rocm():
        attention_config = {"backend": "FLEX_ATTENTION"}

    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    llm = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
        attention_config=attention_config,
    )

    yield weakref.proxy(llm)

    # Drop the fixture's strong reference before tearing down, so the engine
    # can be collected (assumes no other strong references remain — confirm).
    del llm

    cleanup_dist_env_and_memory()


# NOTE(review): presumably opts this test out of a conftest-level global
# cleanup so the shared module-scoped ``llm`` fixture stays usable — confirm
# against conftest.
@pytest.mark.skip_global_cleanup
def test_token_embed(llm: LLM):
    """Check the ``token_embed`` pooling task's output shape for one prompt.

    The multi-vector output for the single prompt must be an (11, 384) tensor
    (presumably one 384-dim row per prompt token).
    """
    outputs = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
    multi_vector = outputs[0].outputs.data
    assert multi_vector.shape == (11, 384)


def test_pooling_params(llm: LLM):
    """Check that ``PoolingParams.use_activation`` controls L2 normalization.

    With ``use_activation=None`` the embeddings must match ``True``. With
    ``False`` they must differ. L2-normalizing the ``False`` output must
    recover the ``True`` output. All comparisons use ``atol=1e-2``.
    """

    def get_outputs(normalize):
        """Embed ``prompts`` with ``use_activation=normalize``; stack as a tensor.

        Presumably returns shape (num_prompts, embedding_dim) — confirm.
        """
        outputs = llm.embed(
            prompts,
            pooling_params=PoolingParams(use_activation=normalize),
            use_tqdm=False,
        )
        return torch.tensor([x.outputs.embedding for x in outputs])

    default = get_outputs(normalize=None)
    w_normal = get_outputs(normalize=True)
    wo_normal = get_outputs(normalize=False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal, atol=1e-2), (
        "wo_normal should not use normal."
    )
    # Normalizing the raw (activation-off) embeddings along the feature axis
    # should reproduce what the engine returns with activation on.
    assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )