# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import dataclass

5
6
7
8
import pytest

from tests.models.utils import EmbedModelInfo
from vllm import PoolingParams
9
from vllm.config import ModelConfig, PoolerConfig
10
11
12

# Models used to parametrize test_embed_dimensions: one flagged
# is_matryoshka=False and one flagged is_matryoshka=True that lists 256 as a
# supported Matryoshka dimension.
EMBEDDING_MODELS = [
    EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
    EmbedModelInfo(
        "Snowflake/snowflake-arctic-embed-m-v1.5",
        is_matryoshka=True,
        matryoshka_dimensions=[256],
    ),
]

20
# PoolingParams keyword names, grouped the way the tests below use them to
# decide which arguments a given task is expected to reject.
classify_parameters = ["use_activation"]
embed_parameters = ["dimensions", "use_activation"]
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]


@dataclass
class MockModelConfig:
    """Lightweight stand-in for a model config that carries only a pooler config.

    Tests in this file pass instances of it to ``PoolingParams.verify``.
    """

    # The only attribute this mock provides.
    pooler_config: PoolerConfig

29
30
31

def test_embed():
    """Check which ``PoolingParams`` arguments the "embed" task accepts.

    ``use_activation`` must pass ``verify`` when set to ``None``, ``True`` or
    ``False``. Parameters listed for classify or step pooling that are not
    also embed parameters must raise ``ValueError``.
    """
    task = "embed"
    model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

    # Each use_activation setting must construct and verify without error.
    for use_activation in (None, True, False):
        pooling_params = PoolingParams(task=task, use_activation=use_activation)
        pooling_params.verify(model_config)

    # The ValueError may come from construction or from verify(), so both
    # statements stay inside the raises block. Sorted for a stable order.
    invalid_parameters = classify_parameters + step_pooling_parameters
    for p in sorted(set(invalid_parameters) - set(embed_parameters)):
        with pytest.raises(ValueError):
            pooling_params = PoolingParams(task=task, **{p: True})
            pooling_params.verify(model_config)
48
49
50
51
52
53
54
55
56
57
58
59
60
61


@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
def test_embed_dimensions(model_info: EmbedModelInfo):
    """Check ``dimensions`` validation for the "embed" task per model.

    For every entry in ``EMBEDDING_MODELS``: ``dimensions=None`` must verify,
    ``dimensions=1`` must raise ``ValueError``, and for models flagged
    ``is_matryoshka`` the first listed Matryoshka dimension must verify.
    """
    task = "embed"
    model_config = ModelConfig(
        model_info.name,
        tokenizer=model_info.name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    # Leaving dimensions unset must be accepted.
    pooling_params = PoolingParams(task=task, dimensions=None)
    pooling_params.verify(model_config)

    # dimensions=1 is expected to be rejected for both test models; the error
    # may surface at construction or in verify().
    with pytest.raises(ValueError):
        pooling_params = PoolingParams(task=task, dimensions=1)
        pooling_params.verify(model_config)

    if model_info.is_matryoshka:
        assert model_info.matryoshka_dimensions is not None
        # A dimension the model info lists as supported must be accepted.
        supported_dim = model_info.matryoshka_dimensions[0]
        pooling_params = PoolingParams(task=task, dimensions=supported_dim)
        pooling_params.verify(model_config)
75
76


77
@pytest.mark.parametrize("task", ["classify"])
def test_classify(task):
    """Check which ``PoolingParams`` arguments the "classify" task accepts.

    ``use_activation`` must pass ``verify`` when set to ``None``, ``True`` or
    ``False``. Parameters listed for embed or step pooling that are not also
    classify parameters must raise ``ValueError``.
    """
    model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))

    # Each use_activation setting must construct and verify without error.
    for use_activation in (None, True, False):
        pooling_params = PoolingParams(task=task, use_activation=use_activation)
        pooling_params.verify(model_config)

    # The ValueError may come from construction or from verify(), so both
    # statements stay inside the raises block. Sorted for a stable order.
    invalid_parameters = embed_parameters + step_pooling_parameters
    for p in sorted(set(invalid_parameters) - set(classify_parameters)):
        with pytest.raises(ValueError):
            pooling_params = PoolingParams(task=task, **{p: True})
            pooling_params.verify(model_config)
95
96
97
98
99
100


@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
def test_token_embed(pooling_type: str):
    """Check which ``PoolingParams`` arguments the "token_embed" task accepts.

    Runs once per token pooling type. ``use_activation`` must pass ``verify``
    when set to ``None``, ``True`` or ``False``. Step-pooling parameters are
    expected to raise ``ValueError`` only when the pooling type is not
    ``"STEP"``.
    """
    task = "token_embed"
    model_config = MockModelConfig(
        pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
    )

    # Each use_activation setting must construct and verify without error.
    for use_activation in (None, True, False):
        pooling_params = PoolingParams(task=task, use_activation=use_activation)
        pooling_params.verify(model_config)

    # Step-pooling parameters only count as invalid outside STEP pooling.
    invalid_parameters = classify_parameters
    if pooling_type != "STEP":
        invalid_parameters = classify_parameters + step_pooling_parameters

    # Names shared with embed_parameters are removed before checking, so the
    # set can be empty (it is for STEP). Sorted for a stable order.
    for p in sorted(set(invalid_parameters) - set(embed_parameters)):
        with pytest.raises(ValueError):
            pooling_params = PoolingParams(task=task, **{p: True})
            pooling_params.verify(model_config)
121
122


123
124
125
126
@pytest.mark.parametrize("pooling_type", ["ALL", "STEP"])
def test_token_classify(pooling_type: str):
    """Check which ``PoolingParams`` arguments "token_classify" accepts.

    Runs once per token pooling type. ``use_activation`` must pass ``verify``
    when set to ``None``, ``True`` or ``False``. ``dimensions`` must raise
    ``ValueError``; step-pooling parameters are expected to raise only when
    the pooling type is not ``"STEP"``.
    """
    task = "token_classify"
    model_config = MockModelConfig(
        pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
    )

    # Each use_activation setting must construct and verify without error.
    for use_activation in (None, True, False):
        pooling_params = PoolingParams(task=task, use_activation=use_activation)
        pooling_params.verify(model_config)

    # Step-pooling parameters only count as invalid outside STEP pooling.
    invalid_parameters = embed_parameters
    if pooling_type != "STEP":
        invalid_parameters = embed_parameters + step_pooling_parameters

    # Names shared with classify_parameters are removed before checking.
    # Sorted for a stable order.
    for p in sorted(set(invalid_parameters) - set(classify_parameters)):
        with pytest.raises(ValueError):
            pooling_params = PoolingParams(task=task, **{p: True})
            pooling_params.verify(model_config)