test_embedding.py 1.72 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch
import torch.nn.functional as F

from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "intfloat/multilingual-e5-small"

prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(scope="module")
def llm():
    """Yield a weak proxy to a module-scoped ``LLM`` built for ``MODEL_NAME``.

    The engine is created once per test module. After the last test that
    uses it, the only strong reference is dropped and
    ``cleanup_dist_env_and_memory()`` is called.
    """
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
    engine = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
    )

    yield weakref.proxy(engine)

    # Teardown: drop the sole strong reference before running the cleanup
    # helper, so the proxied engine can be collected first.
    del engine

    cleanup_dist_env_and_memory()


38
39
40
41
42
43
44
@pytest.mark.skip_global_cleanup
def test_encode_api(llm: LLM):
    """Check the shape returned by ``encode`` with the ``token_embed`` task.

    Encodes the module-level ``prompts`` and expects the first result's data
    to have shape ``(11, 384)`` (presumably tokens x hidden size for this
    prompt and model; verify if either changes).
    """
    results = llm.encode(prompts, pooling_task="token_embed", use_tqdm=False)
    first_result = results[0]
    token_vectors = first_result.outputs.data
    assert token_vectors.shape == (11, 384)


45
46
def test_pooling_params(llm: LLM):
    """Check how ``PoolingParams(normalize=...)`` affects ``llm.embed`` output.

    Embeds the module-level ``prompts`` with ``normalize`` set to ``None``,
    ``True`` and ``False`` and asserts that:

    * the default (``None``) matches ``normalize=True``;
    * ``normalize=True`` and ``normalize=False`` give different embeddings;
    * L2-normalizing the ``normalize=False`` embeddings along the last
      dimension reproduces the ``normalize=True`` embeddings.

    All comparisons use an absolute tolerance of ``1e-2``.
    """

    def get_outputs(normalize):
        # Embed the prompts with the requested setting and stack the
        # per-prompt embeddings into a single tensor.
        outputs = llm.embed(
            prompts, pooling_params=PoolingParams(normalize=normalize), use_tqdm=False
        )
        return torch.tensor([x.outputs.embedding for x in outputs])

    default = get_outputs(normalize=None)
    w_normal = get_outputs(normalize=True)
    wo_normal = get_outputs(normalize=False)

    assert torch.allclose(default, w_normal, atol=1e-2), "Default should use normal."
    assert not torch.allclose(w_normal, wo_normal, atol=1e-2), (
        "wo_normal should not use normal."
    )
    assert torch.allclose(w_normal, F.normalize(wo_normal, p=2, dim=-1), atol=1e-2), (
        "w_normal should be close to normal(wo_normal)."
    )