# pooling_params.py 6.7 KB
# Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import TYPE_CHECKING, Annotated, Any, Optional
6

7
import msgspec
8

9
from vllm.sampling_params import RequestOutputKind
10
from vllm.tasks import PoolingTask
11

12
13
14
if TYPE_CHECKING:
    from vllm.config import ModelConfig

15
16

class PoolingParams(
    msgspec.Struct,
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,
):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Attributes:
        truncate_prompt_tokens: Controls prompt truncation.
            Set to -1 to use the model's default truncation size.
            Set to k to keep only the last k tokens (left truncation).
            Set to None to disable truncation.
        dimensions: Reduce the dimensions of embeddings
            if the model supports matryoshka representation.
        normalize: Whether to normalize the embeddings outputs.
        activation: Whether to apply activation function to
            the classification outputs.
        softmax: Whether to apply softmax to the reward outputs.
        step_tag_id: Reward-model option; only accepted for the
            "encode" task.
        returned_token_ids: Reward-model option; only accepted for the
            "encode" task.
    """

    # NOTE: the struct is declared with array_like=True, i.e. msgspec
    # encodes the fields positionally. Keep the declaration order stable.

    # --8<-- [start:common-pooling-params]
    truncate_prompt_tokens: Optional[Annotated[int, msgspec.Meta(ge=-1)]] = None
    # --8<-- [end:common-pooling-params]

    ## for embeddings models
    # --8<-- [start:embedding-pooling-params]
    dimensions: Optional[int] = None
    normalize: Optional[bool] = None
    # --8<-- [end:embedding-pooling-params]

    ## for classification, scoring and rerank
    # --8<-- [start:classification-pooling-params]
    activation: Optional[bool] = None
    # --8<-- [end:classification-pooling-params]

    ## for reward models
    softmax: Optional[bool] = None
    step_tag_id: Optional[int] = None
    returned_token_ids: Optional[list[int]] = None

    task: Optional[PoolingTask] = None
    """Internal use only."""

    requires_token_ids: bool = False
    """Internal use only."""

    extra_kwargs: Optional[dict[str, Any]] = None
    """Internal use only."""

    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """Names of every task-specific parameter, across all tasks."""
        return [
            "dimensions",
            "normalize",
            "activation",
            "softmax",
            "step_tag_id",
            "returned_token_ids",
        ]

    @property
    def valid_parameters(self) -> dict[str, list[str]]:
        """Mapping from pooling task name to the parameters it accepts."""
        return {
            "embed": ["dimensions", "normalize"],
            "classify": ["activation"],
            "score": ["activation"],
            "encode": ["softmax", "step_tag_id", "returned_token_ids"],
        }

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return deepcopy(self)

    def verify(
        self, task: PoolingTask, model_config: Optional["ModelConfig"] = None
    ) -> None:
        """Bind these parameters to `task`, fill in defaults and validate.

        Args:
            task: The pooling task the request is run for. It is stored on
                the instance if no task has been set yet.
            model_config: Optional model configuration used to look up
                pooler defaults and matryoshka support.

        Raises:
            ValueError: If a different task is already set, the task is
                unknown, the requested dimensions are not supported, or a
                parameter that the task does not accept has a value.
        """
        if self.task is None:
            self.task = task
        elif self.task != task:
            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
            raise ValueError(msg)

        # NOTE: Task validation needs to be done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method

        self._merge_default_parameters(model_config)
        self._set_default_parameters(model_config)
        self._verify_valid_parameters()

    def _task_parameters(self) -> list[str]:
        """Return the parameter names accepted by the current task.

        Raises:
            ValueError: If the current task has no entry in
                `valid_parameters`.
        """
        assert self.task is not None, "task must be set"
        try:
            return self.valid_parameters[self.task]
        except KeyError:
            # Report an unknown task with the same error that
            # `_set_default_parameters` uses instead of a bare KeyError.
            raise ValueError(f"Unknown pooling task: {self.task}") from None

    def _merge_default_parameters(
        self, model_config: Optional["ModelConfig"] = None
    ) -> None:
        """Copy unset task parameters from the model's pooler config.

        Only parameters valid for the current task are considered, and a
        value that is already set on this instance is never overwritten.
        Does nothing when no model config or pooler config is available.
        """
        if model_config is None:
            return

        pooler_config = model_config.pooler_config
        if pooler_config is None:
            return

        for name in self._task_parameters():
            default = getattr(pooler_config, name, None)
            if default is not None and getattr(self, name, None) is None:
                setattr(self, name, default)

    def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
        """Apply per-task fallback defaults and check embedding dimensions.

        Raises:
            ValueError: If the task is unknown, or if `dimensions` is set
                for an "embed" task and the model config does not allow it.
        """
        if self.task == "embed":
            if self.normalize is None:
                self.normalize = True

            # Dimensions can only be checked when a model config is given.
            if self.dimensions is not None and model_config is not None:
                if not model_config.is_matryoshka:
                    raise ValueError(
                        f'Model "{model_config.served_model_name}" does not '
                        "support matryoshka representation, "
                        "changing output dimensions will lead to poor results."
                    )

                mds = model_config.matryoshka_dimensions
                if mds is not None:
                    if self.dimensions not in mds:
                        raise ValueError(
                            f'Model "{model_config.served_model_name}" '
                            f"only supports {mds!s} matryoshka dimensions, "
                            "use other output dimensions will "
                            "lead to poor results."
                        )
                elif self.dimensions < 1:
                    raise ValueError("Dimensions must be greater than 0")

        elif self.task in ("classify", "score"):
            if self.activation is None:
                self.activation = True

        elif self.task == "encode":
            if self.softmax is None:
                self.softmax = True
        else:
            raise ValueError(f"Unknown pooling task: {self.task}")

    def _verify_valid_parameters(self):
        """Reject parameters that are set but not accepted by the task.

        Raises:
            ValueError: If any parameter outside the task's valid set has a
                non-None value.
        """
        valid_parameters = self._task_parameters()
        invalid_parameters = [
            name
            for name in self.all_parameters
            if name not in valid_parameters
            and getattr(self, name, None) is not None
        ]

        if invalid_parameters:
            raise ValueError(
                f"Task {self.task} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters"
            )

    def __repr__(self) -> str:
        """Return a compact summary of the task and its parameters."""
        return (
            f"PoolingParams("
            f"task={self.task}, "
            f"normalize={self.normalize}, "
            f"dimensions={self.dimensions}, "
            f"activation={self.activation}, "
            f"softmax={self.softmax}, "
            f"step_tag_id={self.step_tag_id}, "
            f"returned_token_ids={self.returned_token_ids}, "
            f"requires_token_ids={self.requires_token_ids}, "
            f"extra_kwargs={self.extra_kwargs})"
        )

    def __post_init__(self) -> None:
        """Check that `output_kind` is FINAL_ONLY, the only supported kind."""
        assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
            "For pooling output_kind has to be FINAL_ONLY"
        )