pooling_params.py 8.08 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import Annotated, Any
6

7
import msgspec
8

9
from vllm.config import ModelConfig, PoolerConfig
10
from vllm.sampling_params import RequestOutputKind
11
from vllm.tasks import PoolingTask
12

13
14

class PoolingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Attributes:
        truncate_prompt_tokens: Controls prompt truncation.
            Set to -1 to use the model's default truncation size.
            Set to k to keep only the last k tokens (left truncation).
            Set to None to disable truncation.
        use_activation: Whether to apply the activation function to the
            pooler outputs. `None` uses the pooler's default, which is `True`
            in most cases.
        dimensions: Reduces the output dimensions of embeddings. Only allowed
            if the model supports Matryoshka representation. It must be a
            positive integer and, when the model lists its supported
            Matryoshka dimensions, one of those values.
        step_tag_id: Step-pooling only. Filled from the pooler config when
            unset and the pooler's token pooling type is "STEP". Rejected
            when the pooler's token pooling type is anything else.
        returned_token_ids: Step-pooling only. Same rules as `step_tag_id`.
    """

    # NOTE: field order is part of the wire format because the struct is
    # encoded with ``array_like=True``; do not reorder fields.

    # --8<-- [start:common-pooling-params]
    truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
    use_activation: bool | None = None
    # --8<-- [end:common-pooling-params]

    ## for embeddings models
    # --8<-- [start:embed-pooling-params]
    dimensions: int | None = None
    # --8<-- [end:embed-pooling-params]

    ## for classification, scoring and rerank
    # --8<-- [start:classify-pooling-params]
    # --8<-- [end:classify-pooling-params]

    ## for step pooling models
    step_tag_id: int | None = None
    returned_token_ids: list[int] | None = None

    ## Internal use only
    task: PoolingTask | None = None
    requires_token_ids: bool = False
    skip_reading_prefix_cache: bool | None = None
    extra_kwargs: dict[str, Any] | None = None
    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """User-facing parameters whose validity depends on the task."""
        return ["dimensions", "use_activation"]

    @property
    def valid_parameters(self) -> dict[str, list[str]]:
        """Map each known (non-plugin) task to the parameters it accepts."""
        return {
            "embed": ["dimensions", "use_activation"],
            "classify": ["use_activation"],
            "score": ["use_activation"],
            "token_embed": ["dimensions", "use_activation"],
            "token_classify": ["use_activation"],
        }

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return deepcopy(self)

    def verify(
        self, task: PoolingTask, model_config: "ModelConfig | None" = None
    ) -> None:
        """Bind ``task`` to these params, then fill in defaults and validate.

        Args:
            task: The pooling task for this request. If ``self.task`` is
                already set, it must be equal to ``task``.
            model_config: Optional model config. When given, its pooler
                config supplies defaults for unset parameters, and its
                Matryoshka settings constrain ``dimensions``.

        Raises:
            ValueError: If ``task`` conflicts with an already-set task, the
                task is unknown, or a parameter is invalid for the task or
                model.
        """
        if self.task is None:
            self.task = task
        elif self.task != task:
            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
            raise ValueError(msg)

        # plugin task uses io_processor.parse_request to verify inputs,
        # skipping PoolingParams verify
        if self.task == "plugin":
            if self.skip_reading_prefix_cache is None:
                self.skip_reading_prefix_cache = True
            return

        # NOTE: Task validation needs to done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method
        self._merge_default_parameters(model_config)
        self._set_default_parameters(model_config)
        self._verify_valid_parameters()

    def _get_valid_parameters(self) -> list[str]:
        """Return the parameters accepted by ``self.task``.

        Raises:
            ValueError: If ``self.task`` is not a known pooling task. This is
                the same error ``_set_default_parameters`` raises, so unknown
                tasks fail consistently instead of with a bare ``KeyError``.
        """
        assert self.task is not None, "task must be set"
        valid_parameters = self.valid_parameters.get(self.task)
        if valid_parameters is None:
            raise ValueError(f"Unknown pooling task: {self.task}")
        return valid_parameters

    def _merge_default_parameters(
        self, model_config: "ModelConfig | None" = None
    ) -> None:
        """Fill unset parameters from the model's pooler config.

        Does nothing when ``model_config`` or its ``pooler_config`` is None.
        Values already set on the request take precedence over the pooler
        config.
        """
        if model_config is None:
            return

        pooler_config = model_config.pooler_config
        if pooler_config is None:
            return

        valid_parameters = self._get_valid_parameters()

        for k in valid_parameters:
            default = getattr(pooler_config, k, None)
            if default is not None and getattr(self, k, None) is None:
                setattr(self, k, default)

        # NOTE(review): this default is only applied when a pooler config is
        # available; with model_config=None it stays None.
        if self.skip_reading_prefix_cache is None:
            # If prefix caching is enabled,
            # the output of all pooling may less than n_prompt_tokens,
            # we need to skip reading cache at this request.
            self.skip_reading_prefix_cache = self.task in [
                "token_embed",
                "token_classify",
            ]

        self._verify_step_pooling(pooler_config, valid_parameters)

    def _verify_step_pooling(
        self, pooler_config: "PoolerConfig", valid_parameters: list[str]
    ) -> None:
        """Reconcile step-pooling parameters with the pooler config.

        If the pooler's token pooling type is not "STEP", the step-pooling
        parameters must be unset. Otherwise, any unset ones are filled from
        the pooler config.

        Raises:
            ValueError: If step-pooling parameters are set for a non-STEP
                pooler.
        """
        step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
        if pooler_config.tok_pooling_type != "STEP":
            invalid_parameters = [
                k
                for k in step_pooling_parameters
                if getattr(self, k, None) is not None
            ]
            if invalid_parameters:
                raise ValueError(
                    f"Task {self.task} only supports {valid_parameters} "
                    f"parameters, does not support "
                    f"{invalid_parameters} parameters"
                )
        else:
            for k in step_pooling_parameters:
                default = getattr(pooler_config, k, None)
                if default is not None and getattr(self, k, None) is None:
                    setattr(self, k, default)

    def _set_default_parameters(self, model_config: "ModelConfig | None") -> None:
        """Apply per-task defaults and validate ``dimensions``.

        Raises:
            ValueError: If the task is unknown, or ``dimensions`` is not
                supported by the model or is not positive.
        """
        if self.task in ["embed", "token_embed"]:
            if self.use_activation is None:
                self.use_activation = True

            if self.dimensions is not None:
                mds = None
                if model_config is not None:
                    if not model_config.is_matryoshka:
                        raise ValueError(
                            f'Model "{model_config.served_model_name}" does not '
                            f"support matryoshka representation, "
                            f"changing output dimensions will lead to poor results."
                        )
                    mds = model_config.matryoshka_dimensions

                if mds is not None:
                    if self.dimensions not in mds:
                        raise ValueError(
                            f'Model "{model_config.served_model_name}" '
                            f"only supports {str(mds)} matryoshka dimensions, "
                            f"use other output dimensions will "
                            f"lead to poor results."
                        )
                # Checked even without a model config, so non-positive
                # dimensions are never accepted silently.
                elif self.dimensions < 1:
                    raise ValueError("Dimensions must be greater than 0")

        elif self.task in ["classify", "score", "token_classify"]:
            if self.use_activation is None:
                self.use_activation = True
        else:
            raise ValueError(f"Unknown pooling task: {self.task}")

    def _verify_valid_parameters(self) -> None:
        """Reject task-dependent parameters that are set but not valid.

        Raises:
            ValueError: If any parameter outside the task's valid set is set.
        """
        valid_parameters = self._get_valid_parameters()
        invalid_parameters = [
            k
            for k in self.all_parameters
            if k not in valid_parameters and getattr(self, k, None) is not None
        ]

        if invalid_parameters:
            raise ValueError(
                f"Task {self.task} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def __repr__(self) -> str:
        return (
            f"PoolingParams("
            f"task={self.task}, "
            f"dimensions={self.dimensions}, "
            f"use_activation={self.use_activation}, "
            f"step_tag_id={self.step_tag_id}, "
            f"returned_token_ids={self.returned_token_ids}, "
            f"requires_token_ids={self.requires_token_ids}, "
            f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
            f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
            f"extra_kwargs={self.extra_kwargs})"
        )

    def __post_init__(self) -> None:
        # Pooling only supports FINAL_ONLY output (see the message below).
        assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
            "For pooling output_kind has to be FINAL_ONLY"
        )