pooling_params.py 8.04 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import Any
6

7
import msgspec
8

9
from vllm.config import ModelConfig, PoolerConfig
10
from vllm.sampling_params import RequestOutputKind
11
from vllm.tasks import PoolingTask
12

13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class LateInteractionParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """Per-request metadata that drives worker-side late-interaction scoring.

    Attributes:
        mode: Which late-interaction step the request performs:
            - ``"cache_query"``: cache the query's token embeddings.
            - ``"score_doc"``: score a document against a cached query.
        query_key: Stable key shared by DP routing and the worker-side
            cache lookup, so a query and its documents meet on one worker.
        query_uses: Expected number of document requests that will score
            against this query; ``None`` when the caller does not say.
    """

    # Keep this field order stable: with ``array_like=True`` msgspec
    # serializes the struct positionally rather than by field name.
    mode: str
    query_key: str
    query_uses: int | None = None


34
class PoolingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Attributes:
        use_activation: Whether to apply activation function to the pooler outputs.
            `None` means "take the value from the model's pooler config if it
            sets one, otherwise `True`" (resolved in `verify`).
        dimensions: Reduce the dimensions of embeddings if the model supports
            matryoshka representation. Only accepted for `embed` and
            `token_embed` tasks.
        step_tag_id: Step-pooling option. Only accepted when the model's
            pooler config has `tok_pooling_type == "STEP"`; when unset it is
            filled from the pooler config.
        returned_token_ids: Step-pooling option with the same rules as
            `step_tag_id`.
    """

    # NOTE: with ``array_like=True`` msgspec encodes fields positionally, so
    # the declaration order below is part of the wire format.

    # --8<-- [start:common-pooling-params]
    use_activation: bool | None = None
    # --8<-- [end:common-pooling-params]

    ## for embeddings models
    # --8<-- [start:embed-pooling-params]
    dimensions: int | None = None
    # --8<-- [end:embed-pooling-params]

    ## for classification, scoring and rerank
    # --8<-- [start:classify-pooling-params]
    # --8<-- [end:classify-pooling-params]

    ## for step pooling models
    step_tag_id: int | None = None
    returned_token_ids: list[int] | None = None

    ## Internal use only
    task: PoolingTask | None = None
    requires_token_ids: bool = False
    # `None` means "decide in verify()"; the default depends on the task.
    skip_reading_prefix_cache: bool | None = None
    late_interaction_params: LateInteractionParams | None = None
    extra_kwargs: dict[str, Any] | None = None
    # Pooling only ever emits a final result; enforced in __post_init__.
    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """Names of every user-facing parameter checked per task."""
        return ["dimensions", "use_activation"]

    @property
    def valid_parameters(self) -> dict[str, list[str]]:
        """Map each supported task to the user-facing parameters it accepts."""
        return {
            "embed": ["dimensions", "use_activation"],
            "classify": ["use_activation"],
            "score": ["use_activation"],
            "token_embed": ["dimensions", "use_activation"],
            "token_classify": ["use_activation"],
        }

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return deepcopy(self)

    def verify(self, model_config: ModelConfig) -> None:
        """Validate this request against `model_config`, filling defaults in place.

        Raises:
            ValueError: if the task is unknown, a parameter is not supported
                by the task or pooler, or the requested `dimensions` are not
                supported by the model.
        """
        # plugin task uses io_processor.parse_request to verify inputs,
        # skipping PoolingParams verify
        if self.task == "plugin":
            if self.skip_reading_prefix_cache is None:
                self.skip_reading_prefix_cache = True
            return

        # NOTE: Task validation needs to done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method
        self._merge_default_parameters(model_config)
        self._set_default_parameters(model_config)
        self._verify_valid_parameters()

    def _valid_parameters_for_task(self) -> list[str]:
        """Return the parameters accepted by `self.task`.

        Raises ValueError for an unknown task, so that every verification path
        reports an unknown task the same way. Previously this path raised a
        bare KeyError.
        """
        assert self.task is not None, "task must be set"
        try:
            return self.valid_parameters[self.task]
        except KeyError:
            raise ValueError(f"Unknown pooling task: {self.task!r}") from None

    def _inherit_unset(self, pooler_config: PoolerConfig, names: list[str]) -> None:
        """Copy each attribute in `names` from `pooler_config` when the request
        leaves it as None and the config provides a non-None value."""
        for name in names:
            default = getattr(pooler_config, name, None)
            if default is not None and getattr(self, name, None) is None:
                setattr(self, name, default)

    def _merge_default_parameters(self, model_config: ModelConfig) -> None:
        """Fill unset request parameters from the model's pooler config.

        Also resolves `skip_reading_prefix_cache` and checks the step-pooling
        parameters. Does nothing when the model has no pooler config.
        """
        pooler_config = model_config.pooler_config
        if pooler_config is None:
            # NOTE(review): `skip_reading_prefix_cache` stays None on this
            # path; confirm a pooling model always has a pooler config.
            return

        valid_parameters = self._valid_parameters_for_task()
        self._inherit_unset(pooler_config, valid_parameters)

        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled, all-token pooling may return fewer
            # outputs than n_prompt_tokens, so token-level tasks must skip
            # reading the prefix cache for this request.
            self.skip_reading_prefix_cache = self.task in (
                "token_embed",
                "token_classify",
            )

        self._verify_step_pooling(pooler_config, valid_parameters)

    def _verify_step_pooling(
        self,
        pooler_config: PoolerConfig,
        valid_parameters: list[str],
    ) -> None:
        """Reject step-pooling parameters unless the pooler uses STEP pooling;
        otherwise fill them from the pooler config when unset."""
        step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
        if pooler_config.tok_pooling_type != "STEP":
            invalid_parameters = [
                k
                for k in step_pooling_parameters
                if getattr(self, k, None) is not None
            ]
            if invalid_parameters:
                raise ValueError(
                    f"Task {self.task} only supports {valid_parameters} "
                    f"parameters, does not support "
                    f"{invalid_parameters} parameters"
                )
        else:
            self._inherit_unset(pooler_config, step_pooling_parameters)

    def _set_default_parameters(self, model_config: ModelConfig) -> None:
        """Apply task-specific defaults and validate `dimensions`.

        Raises:
            ValueError: for an unknown task, or for `dimensions` the model
                cannot honour (non-matryoshka model, a size outside the
                declared matryoshka dimensions, or a size < 1).
        """
        if self.task in ["embed", "token_embed"]:
            if self.use_activation is None:
                self.use_activation = True

            if self.dimensions is not None:
                if not model_config.is_matryoshka:
                    raise ValueError(
                        f'Model "{model_config.served_model_name}" does not '
                        f"support matryoshka representation, "
                        f"changing output dimensions will lead to poor results."
                    )

                mds = model_config.matryoshka_dimensions
                if mds is not None:
                    if self.dimensions not in mds:
                        raise ValueError(
                            f"Model {model_config.served_model_name!r} "
                            f"only supports {str(mds)} matryoshka dimensions, "
                            f"use other output dimensions will "
                            f"lead to poor results."
                        )
                elif self.dimensions < 1:
                    raise ValueError("Dimensions must be greater than 0")

        elif self.task in ["classify", "score", "token_classify"]:
            if self.use_activation is None:
                self.use_activation = True
        else:
            raise ValueError(f"Unknown pooling task: {self.task!r}")

    def _verify_valid_parameters(self) -> None:
        """Raise ValueError if a parameter not accepted by the task is set."""
        valid_parameters = self._valid_parameters_for_task()
        invalid_parameters = [
            k
            for k in self.all_parameters
            if k not in valid_parameters and getattr(self, k, None) is not None
        ]
        if invalid_parameters:
            raise ValueError(
                f"Task {self.task!r} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def __repr__(self) -> str:
        return (
            f"PoolingParams("
            f"task={self.task}, "
            f"dimensions={self.dimensions}, "
            f"use_activation={self.use_activation}, "
            f"step_tag_id={self.step_tag_id}, "
            f"returned_token_ids={self.returned_token_ids}, "
            f"requires_token_ids={self.requires_token_ids}, "
            f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
            f"late_interaction_params={self.late_interaction_params}, "
            f"extra_kwargs={self.extra_kwargs})"
        )

    def __post_init__(self) -> None:
        assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
            "For pooling output_kind has to be FINAL_ONLY"
        )