"vscode:/vscode.git/clone" did not exist on "cb0a7b4bea26657da989562a10055b7d0b59fd3a"
test_classify.py 2.06 KB
Newer Older
1
2
3
4
5
6
7
8
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import weakref

import pytest
import torch

9
from tests.models.utils import softmax
10
11
12
13
14
15
16
17
18
19
20
21
from vllm import LLM, PoolingParams
from vllm.distributed import cleanup_dist_env_and_memory

MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"

prompts = ["The chef prepared a delicious meal."]


@pytest.fixture(scope="module")
def llm():
    # pytest caches the fixture so we use weakref.proxy to
    # enable garbage collection
22
23
24
25
26
27
28
29
    llm = LLM(
        model=MODEL_NAME,
        max_num_batched_tokens=32768,
        tensor_parallel_size=1,
        gpu_memory_utilization=0.75,
        enforce_eager=True,
        seed=0,
    )
30

31
    yield weakref.proxy(llm)
32

33
    del llm
34
35
36
37
38
39
40
41

    cleanup_dist_env_and_memory()


@pytest.mark.skip_global_cleanup
def test_pooling_params(llm: LLM):
    def get_outputs(activation):
        outputs = llm.classify(
42
43
            prompts, pooling_params=PoolingParams(activation=activation), use_tqdm=False
        )
44
45
46
47
48
49
        return torch.tensor([x.outputs.probs for x in outputs])

    default = get_outputs(activation=None)
    w_activation = get_outputs(activation=True)
    wo_activation = get_outputs(activation=False)

50
51
52
53
54
55
56
57
58
    assert torch.allclose(default, w_activation, atol=1e-2), (
        "Default should use activation."
    )
    assert not torch.allclose(w_activation, wo_activation, atol=1e-2), (
        "wo_activation should not use activation."
    )
    assert torch.allclose(softmax(wo_activation), w_activation, atol=1e-2), (
        "w_activation should be close to activation(wo_activation)."
    )
59
60


61
@pytest.mark.skip_global_cleanup
62
def test_encode_api(llm: LLM):
63
    # chunked prefill does not support all pooling
64
65
    err_msg = "pooling_task must be one of.+"
    with pytest.raises(ValueError, match=err_msg):
66
        llm.encode(prompts, pooling_task="token_classify", use_tqdm=False)
67
68
69
70
71
72


def test_score_api(llm: LLM):
    err_msg = "Score API is only enabled for num_labels == 1."
    with pytest.raises(ValueError, match=err_msg):
        llm.score("ping", "pong", use_tqdm=False)