"vllm/vscode:/vscode.git/clone" did not exist on "a2f0ce42dfd7bcce487a54bd61bf37b853a5bd3d"
pooling_params.py 8.19 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import Any
6

7
import msgspec
8

9
from vllm.config import ModelConfig, PoolerConfig
10
from vllm.sampling_params import RequestOutputKind
11
from vllm.tasks import PoolingTask
12

13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class LateInteractionParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """Metadata for worker-side late-interaction scoring.

    The struct is declared with `array_like=True`, so msgspec encodes
    instances positionally; the order of the fields below is therefore part
    of the serialized format and must not be changed casually.

    Attributes:
        mode:
            - "cache_query": cache query token embeddings
            - "score_doc": score a document against a cached query.
            The value is a plain `str`; this class does not validate it.
        query_key: stable key used for both DP routing and worker cache lookup.
        query_uses: expected number of document requests. Defaults to `None`
            when not provided.
    """

    # Which late-interaction step this request performs (see docstring).
    mode: str
    # Shared identifier for the cached query (routing + cache lookup).
    query_key: str
    # Optional: expected number of document requests for the query.
    query_uses: int | None = None


34
class PoolingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """Request-level parameters for pooling models.

    Call `verify` with the model config to fill unset values from the
    model's pooler config and to reject parameters the task does not accept.

    Attributes:
        use_activation: Whether to apply the activation function to the
            pooler outputs. `None` uses the pooler's default, which is
            `True` in most cases.
        dimensions: Reduce the dimensions of the embeddings, if the model
            supports matryoshka representation.
        step_tag_id: Step-pooling option. `verify` raises if it is set while
            the pooler config's `tok_pooling_type` is not "STEP".
        returned_token_ids: Step-pooling option, subject to the same check
            as `step_tag_id`.
    """

    # --8<-- [start:common-pooling-params]
    use_activation: bool | None = None
    # --8<-- [end:common-pooling-params]

    ## for embeddings models
    # --8<-- [start:embed-pooling-params]
    dimensions: int | None = None
    # --8<-- [end:embed-pooling-params]

    ## for classification, scoring and rerank
    # --8<-- [start:classify-pooling-params]
    # --8<-- [end:classify-pooling-params]

    ## for step pooling models
    step_tag_id: int | None = None
    returned_token_ids: list[int] | None = None

    ## Internal use only
    task: PoolingTask | None = None
    requires_token_ids: bool = False
    skip_reading_prefix_cache: bool | None = None
    late_interaction_params: LateInteractionParams | None = None
    extra_kwargs: dict[str, Any] | None = None
    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """Names of every user-facing parameter checked per task."""
        names = ["dimensions", "use_activation"]
        return names

    @property
    def valid_parameters(self):
        """Map each built-in pooling task to the parameters it accepts.

        A fresh dict (with fresh lists) is built on every access.
        """
        return dict(
            embed=["dimensions", "use_activation"],
            classify=["use_activation"],
            score=["use_activation"],
            token_embed=["dimensions", "use_activation"],
            token_classify=["use_activation"],
        )

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        duplicate = deepcopy(self)
        return duplicate

    def verify(self, model_config: ModelConfig) -> None:
        """Fill in defaults and validate the parameters for `self.task`.

        Args:
            model_config: Model configuration supplying the pooler config
                and the matryoshka settings.

        Raises:
            ValueError: If a parameter is not supported by the task or the
                requested `dimensions` is not acceptable for the model.
        """
        # The plugin task validates its inputs through
        # io_processor.parse_request, so only the prefix-cache default
        # is settled here.
        if self.task == "plugin":
            if self.skip_reading_prefix_cache is None:
                self.skip_reading_prefix_cache = True
            return

        # NOTE: Task validation needs to be done against the model instance,
        # which is not available in the model config, so it is not part of
        # this method.
        #
        # Tasks outside the built-in table are left untouched so that
        # plugins can configure and validate the pooling params themselves.
        if self.task in self.valid_parameters:
            self._merge_default_parameters(model_config)
            self._set_default_parameters(model_config)
            self._verify_valid_parameters()

    def _merge_default_parameters(self, model_config: ModelConfig) -> None:
        """Copy unset parameters from the model's pooler config, if any."""
        pooler_config = model_config.pooler_config
        if pooler_config is None:
            return

        assert self.task is not None, "task must be set"
        supported = self.valid_parameters[self.task]

        # A value given on the request always wins; the pooler config only
        # fills parameters that are still None.
        for name in supported:
            configured = getattr(pooler_config, name, None)
            if configured is not None and getattr(self, name, None) is None:
                setattr(self, name, configured)

        if self.skip_reading_prefix_cache is None:
            # When prefix caching is enabled, the output of all pooling may
            # be shorter than n_prompt_tokens, so token-level tasks skip
            # reading the cache for this request.
            self.skip_reading_prefix_cache = self.task in (
                "token_embed",
                "token_classify",
            )

        self._verify_step_pooling(pooler_config, supported)

    def _verify_step_pooling(
        self,
        pooler_config: PoolerConfig,
        valid_parameters: list[str],
    ):
        """Merge or reject the step-pooling parameters.

        With `tok_pooling_type == "STEP"` unset step parameters are taken
        from the pooler config; otherwise any step parameter that is set
        raises `ValueError`.
        """
        step_only = ("step_tag_id", "returned_token_ids")

        if pooler_config.tok_pooling_type == "STEP":
            for name in step_only:
                configured = getattr(pooler_config, name, None)
                if configured is not None and getattr(self, name, None) is None:
                    setattr(self, name, configured)
            return

        invalid_parameters = [
            name for name in step_only if getattr(self, name, None) is not None
        ]
        if invalid_parameters:
            raise ValueError(
                f"Task {self.task} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def _set_default_parameters(self, model_config: ModelConfig):
        """Apply task defaults and validate `dimensions` for embedding tasks.

        Raises:
            ValueError: For an unknown task, or when `dimensions` is set but
                the model cannot honour it.
        """
        embedding_tasks = ("embed", "token_embed")
        classification_tasks = ("classify", "score", "token_classify")

        is_embedding = self.task in embedding_tasks
        if not is_embedding and self.task not in classification_tasks:
            raise ValueError(f"Unknown pooling task: {self.task!r}")

        # Every known task defaults to applying the activation.
        if self.use_activation is None:
            self.use_activation = True

        # Only embedding tasks with an explicit size need further checks.
        if not is_embedding or self.dimensions is None:
            return

        if not model_config.is_matryoshka:
            raise ValueError(
                f'Model "{model_config.served_model_name}" does not '
                f"support matryoshka representation, "
                f"changing output dimensions will lead to poor results."
            )

        allowed_dims = model_config.matryoshka_dimensions
        if allowed_dims is None:
            # No explicit list of sizes: any positive size is accepted.
            if self.dimensions < 1:
                raise ValueError("Dimensions must be greater than 0")
        elif self.dimensions not in allowed_dims:
            raise ValueError(
                f"Model {model_config.served_model_name!r} "
                f"only supports {allowed_dims!s} matryoshka dimensions, "
                f"use other output dimensions will "
                f"lead to poor results."
            )

    def _verify_valid_parameters(self):
        """Raise `ValueError` if a parameter unsupported by the task is set."""
        assert self.task is not None, "task must be set"
        valid_parameters = self.valid_parameters[self.task]

        invalid_parameters = [
            name
            for name in self.all_parameters
            if name not in valid_parameters
            and getattr(self, name, None) is not None
        ]

        if invalid_parameters:
            raise ValueError(
                f"Task {self.task!r} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def __repr__(self) -> str:
        """Render the fields as `PoolingParams(name=value, ...)`.

        `output_kind` is not included in the rendering.
        """
        shown = (
            "task",
            "dimensions",
            "use_activation",
            "step_tag_id",
            "returned_token_ids",
            "requires_token_ids",
            "skip_reading_prefix_cache",
            "late_interaction_params",
            "extra_kwargs",
        )
        rendered = ", ".join(f"{name}={getattr(self, name)}" for name in shown)
        return f"PoolingParams({rendered})"

    def __post_init__(self) -> None:
        """Check that `output_kind` is `FINAL_ONLY` after construction."""
        # NOTE(review): this is an assert, so the check is skipped when
        # Python runs with -O.
        is_final_only = self.output_kind == RequestOutputKind.FINAL_ONLY
        assert is_final_only, "For pooling output_kind has to be FINAL_ONLY"