"tests/vscode:/vscode.git/clone" did not exist on "08f8a5627e6306312202d5c0fda8b7255d2f27fc"
pooling_params.py 8.58 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import Annotated, Any, Optional
6

7
import msgspec
8

9
10
from vllm.config import ModelConfig, PoolerConfig
from vllm.config.pooler import get_use_activation
11
from vllm.sampling_params import RequestOutputKind
12
from vllm.tasks import PoolingTask
13

14
15

class PoolingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Attributes:
        truncate_prompt_tokens: Controls prompt truncation.
            Set to -1 to use the model's default truncation size.
            Set to k to keep only the last k tokens (left truncation).
            Set to None to disable truncation.
        dimensions: Reduce the dimensions of embeddings
            if the model supports matryoshka representation.
        normalize: Whether to normalize the embeddings outputs.
        softmax: softmax will be deprecated, please use use_activation instead.
        activation: activation will be deprecated, please use use_activation instead.
        use_activation: Whether to apply activation function to
            the classification outputs.
        step_tag_id: Step-pooling option; only accepted when the model's
            pooler config uses the "STEP" pooling type.
        returned_token_ids: Step-pooling option; only accepted when the
            model's pooler config uses the "STEP" pooling type.
    """

    # --8<-- [start:common-pooling-params]
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    # --8<-- [end:common-pooling-params]

    ## for embeddings models
    # --8<-- [start:embedding-pooling-params]
    dimensions: int | None = None
    normalize: bool | None = None
    # --8<-- [end:embedding-pooling-params]

    ## for classification, scoring and rerank
    # --8<-- [start:classification-pooling-params]
    softmax: bool | None = None
    activation: bool | None = None
    use_activation: bool | None = None
    # --8<-- [end:classification-pooling-params]

    ## for step pooling models
    step_tag_id: int | None = None
    returned_token_ids: list[int] | None = None

    ## Internal use only
    # Bound by `verify`; once set, a different task may not overwrite it.
    task: PoolingTask | None = None
    requires_token_ids: bool = False
    # Defaulted by `verify` when left as None: True for the plugin task and
    # token-level tasks, False otherwise.
    skip_reading_prefix_cache: bool | None = None
    extra_kwargs: dict[str, Any] | None = None
    # Pooling only supports FINAL_ONLY (enforced in `__post_init__`).
    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """User-facing parameters checked by `_verify_valid_parameters`."""
        return ["dimensions", "normalize", "use_activation"]

    @property
    def valid_parameters(self) -> dict[str, list[str]]:
        """Map each pooling task to the user-facing parameters it accepts."""
        return {
            "embed": ["dimensions", "normalize"],
            "classify": ["use_activation"],
            "score": ["use_activation"],
            "token_embed": ["dimensions", "normalize"],
            "token_classify": ["use_activation"],
        }

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return deepcopy(self)

    def verify(
        self, task: PoolingTask, model_config: Optional["ModelConfig"] = None
    ) -> None:
        """Bind ``task`` to these params, then fill defaults and validate.

        Mutates the instance in place. The "plugin" task only gets its
        ``skip_reading_prefix_cache`` default; its inputs are validated by
        the IO processor instead.

        Args:
            task: The pooling task this request will run.
            model_config: Optional model config used to inherit defaults from
                the pooler config and to validate ``dimensions``.

        Raises:
            ValueError: If a different task was already bound, the task is
                unknown, or a parameter is unsupported by the task or model.
        """
        if self.task is None:
            self.task = task
        elif self.task != task:
            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
            raise ValueError(msg)

        # Fold the deprecated `softmax` / `activation` flags into
        # `use_activation` (the helper is expected to emit the deprecation
        # warning).
        self.use_activation = get_use_activation(self)

        # plugin task uses io_processor.parse_request to verify inputs,
        # skipping PoolingParams verify
        if self.task == "plugin":
            if self.skip_reading_prefix_cache is None:
                self.skip_reading_prefix_cache = True
            return

        # NOTE: Task validation needs to be done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method
        self._merge_default_parameters(model_config)
        # Applied regardless of whether a pooler config is available, so
        # token-level tasks always skip reading the prefix cache.
        self._set_default_skip_reading_prefix_cache()
        self._set_default_parameters(model_config)
        self._verify_valid_parameters()

    def _get_valid_parameters(self) -> list[str]:
        """Return the user-facing parameters accepted by ``self.task``.

        Raises:
            ValueError: If ``self.task`` is not a known pooling task (the
                same error `_set_default_parameters` raises), rather than a
                bare ``KeyError`` from the lookup.
        """
        assert self.task is not None, "task must be set"
        try:
            return self.valid_parameters[self.task]
        except KeyError:
            raise ValueError(f"Unknown pooling task: {self.task}") from None

    def _fill_unset_from(self, source: Any, names: list[str]) -> None:
        """Copy each attribute in ``names`` from ``source`` onto self.

        An attribute is copied only when it is None on self and not None on
        ``source``; explicitly set request values always win.
        """
        for name in names:
            default = getattr(source, name, None)
            if default is not None and getattr(self, name, None) is None:
                setattr(self, name, default)

    def _merge_default_parameters(
        self, model_config: Optional["ModelConfig"] = None
    ) -> None:
        """Inherit unset task parameters from the model's pooler config.

        Does nothing when there is no model config or no pooler config.
        Otherwise copies the task's valid parameters that are unset here and
        then checks/merges the step-pooling parameters.

        Raises:
            ValueError: If the task is unknown or step-pooling parameters are
                set for a non-"STEP" pooler.
        """
        if model_config is None:
            return

        pooler_config = model_config.pooler_config
        if pooler_config is None:
            return

        valid_parameters = self._get_valid_parameters()
        self._fill_unset_from(pooler_config, valid_parameters)
        self._verify_step_pooling(pooler_config, valid_parameters)

    def _set_default_skip_reading_prefix_cache(self) -> None:
        """Default ``skip_reading_prefix_cache`` from the task when unset."""
        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, the output of "all" (per-token)
            # pooling may cover fewer than n_prompt_tokens, so token-level
            # requests must skip reading the cache.
            self.skip_reading_prefix_cache = self.task in [
                "token_embed",
                "token_classify",
            ]

    def _verify_step_pooling(
        self, pooler_config: "PoolerConfig", valid_parameters: list[str]
    ):
        """Check or merge the step-pooling parameters.

        For a pooler whose ``pooling_type`` is not "STEP", setting
        ``step_tag_id`` or ``returned_token_ids`` is an error. For a "STEP"
        pooler, unset values are inherited from ``pooler_config``.

        Raises:
            ValueError: If step-pooling parameters are set for a non-"STEP"
                pooler.
        """
        step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
        if pooler_config.pooling_type != "STEP":
            invalid_parameters = [
                k
                for k in step_pooling_parameters
                if getattr(self, k, None) is not None
            ]
            if invalid_parameters:
                raise ValueError(
                    f"Task {self.task} only supports {valid_parameters} "
                    "parameters, does not support "
                    f"{invalid_parameters} parameters"
                )
            return

        self._fill_unset_from(pooler_config, step_pooling_parameters)

    def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
        """Apply task-specific defaults and validate ``dimensions``.

        Embedding tasks default ``normalize`` to True. When a model config is
        given, a requested ``dimensions`` requires a matryoshka model. If the
        model lists its supported dimensions, the value must be one of them;
        otherwise it must be at least 1. Classification tasks default
        ``use_activation`` to True.

        Raises:
            ValueError: On an unsupported ``dimensions`` value or an unknown
                task.
        """
        if self.task in ["embed", "token_embed"]:
            if self.normalize is None:
                self.normalize = True

            if self.dimensions is not None and model_config is not None:
                if not model_config.is_matryoshka:
                    raise ValueError(
                        f'Model "{model_config.served_model_name}" does not '
                        "support matryoshka representation, "
                        "changing output dimensions will lead to poor results."
                    )

                mds = model_config.matryoshka_dimensions
                if mds is not None:
                    if self.dimensions not in mds:
                        raise ValueError(
                            f'Model "{model_config.served_model_name}" '
                            f"only supports {str(mds)} matryoshka dimensions, "
                            "use other output dimensions will "
                            "lead to poor results."
                        )
                elif self.dimensions < 1:
                    raise ValueError("Dimensions must be greater than 0")

        elif self.task in ["classify", "score", "token_classify"]:
            if self.use_activation is None:
                self.use_activation = True

        else:
            raise ValueError(f"Unknown pooling task: {self.task}")

    def _verify_valid_parameters(self):
        """Reject user-facing parameters that ``self.task`` does not accept.

        Raises:
            ValueError: If any parameter outside the task's valid set is not
                None, or if the task is unknown.
        """
        valid_parameters = self._get_valid_parameters()
        invalid_parameters = [
            k
            for k in self.all_parameters
            if k not in valid_parameters and getattr(self, k, None) is not None
        ]
        if invalid_parameters:
            raise ValueError(
                f"Task {self.task} only supports {valid_parameters} "
                "parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def __repr__(self) -> str:
        return (
            f"PoolingParams("
            f"task={self.task}, "
            f"normalize={self.normalize}, "
            f"dimensions={self.dimensions}, "
            f"use_activation={self.use_activation}, "
            f"step_tag_id={self.step_tag_id}, "
            f"returned_token_ids={self.returned_token_ids}, "
            f"requires_token_ids={self.requires_token_ids}, "
            f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"extra_kwargs={self.extra_kwargs})"
        )

    def __post_init__(self) -> None:
        assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
            "For pooling output_kind has to be FINAL_ONLY"
        )