pooling_params.py 6.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from copy import deepcopy
5
from typing import TYPE_CHECKING, Annotated, Any, Optional
6

7
import msgspec
8

9
from vllm.sampling_params import RequestOutputKind
10
from vllm.tasks import PoolingTask
11

12
13
14
if TYPE_CHECKING:
    from vllm.config import ModelConfig

15
16
17
18
19

class PoolingParams(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
    """API parameters for pooling models.

    Which parameters a request may set depends on the pooling task it is
    bound to in :meth:`verify` (see :attr:`valid_parameters`); setting a
    parameter the task does not support raises ``ValueError``.

    Attributes:
        normalize: Whether to normalize the embedding outputs. Defaults to
                   ``True`` for the ``"embed"`` task when left unset.
        dimensions: Reduce the dimensions of embeddings
                    if the model supports matryoshka representation.
        activation: Whether to apply the activation function to
                    the classification outputs. Defaults to ``True`` for the
                    ``"classify"`` and ``"score"`` tasks when left unset.
        softmax: Whether to apply softmax to the reward outputs. Defaults to
                 ``True`` for the ``"encode"`` task when left unset.
    """

    # NOTE: with ``array_like=True`` msgspec encodes fields positionally, so
    # the declaration order below is part of the serialized format. Append
    # new fields at the end instead of reordering existing ones.

    truncate_prompt_tokens: Optional[Annotated[int,
                                               msgspec.Meta(ge=-1)]] = None
    """If set to -1, will use the truncation size supported by the model. If
    set to an integer k, will use only the last k tokens from the prompt
    (i.e., left truncation). If set to `None`, truncation is disabled."""

    ## for embeddings models
    dimensions: Optional[int] = None
    normalize: Optional[bool] = None

    ## for classification models
    activation: Optional[bool] = None

    ## for reward models
    # step_tag_id / returned_token_ids are only accepted for the "encode"
    # task (see valid_parameters).
    softmax: Optional[bool] = None
    step_tag_id: Optional[int] = None
    returned_token_ids: Optional[list[int]] = None

    task: Optional[PoolingTask] = None
    """Internal use only."""

    requires_token_ids: bool = False
    """Internal use only."""

    extra_kwargs: Optional[dict[str, Any]] = None
    """Internal use only."""

    output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY

    @property
    def all_parameters(self) -> list[str]:
        """Names of every task-specific parameter checked by ``verify``."""
        return [
            "dimensions", "normalize", "activation", "softmax", "step_tag_id",
            "returned_token_ids"
        ]

    @property
    def valid_parameters(self) -> dict[str, list[str]]:
        """Map each supported pooling task to the parameters it accepts."""
        return {
            "embed": ["dimensions", "normalize"],
            "classify": ["activation"],
            "score": ["activation"],
            "encode": ["softmax", "step_tag_id", "returned_token_ids"],
        }

    def clone(self) -> "PoolingParams":
        """Returns a deep copy of the PoolingParams instance."""
        return deepcopy(self)

    def verify(self,
               task: PoolingTask,
               model_config: Optional["ModelConfig"] = None) -> None:
        """Bind these params to ``task``, then fill defaults and validate.

        Mutates ``self`` in place: sets ``self.task`` if unset, copies unset
        task parameters from ``model_config.pooler_config``, applies the
        task's built-in defaults, and rejects parameters the task does not
        accept.

        Args:
            task: The pooling task this request is served under.
            model_config: Optional model config used for pooler-config
                defaults and matryoshka ``dimensions`` validation.

        Raises:
            ValueError: if ``self.task`` is already set to a different task,
                if the task is unknown, or if a parameter is invalid for the
                task or the model.
        """
        if self.task is None:
            self.task = task
        elif self.task != task:
            msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
            raise ValueError(msg)

        # Reject unknown tasks up front with the same error that
        # _set_default_parameters raises. Without this, the dict lookup in
        # _merge_default_parameters surfaces as a bare KeyError whenever the
        # model has a pooler config.
        if self.task not in self.valid_parameters:
            raise ValueError(f"Unknown pooling task: {self.task}")

        # NOTE: Task validation needs to done against the model instance,
        # which is not available in model config. So, it's not included
        # in this method

        self._merge_default_parameters(model_config)
        self._set_default_parameters(model_config)
        self._verify_valid_parameters()

    def _merge_default_parameters(self,
                                  model_config: Optional["ModelConfig"] = None
                                  ) -> None:
        """Copy unset task parameters from the model's pooler config.

        Only parameters valid for ``self.task`` are considered. A value is
        copied only when it is set (non-``None``) on the pooler config and
        unset (``None``) on ``self``, so explicit request values win.
        """
        if model_config is None:
            return

        pooler_config = model_config.pooler_config
        if pooler_config is None:
            return

        assert self.task is not None, "task must be set"

        for name in self.valid_parameters[self.task]:
            default = getattr(pooler_config, name, None)
            if default is not None and getattr(self, name, None) is None:
                setattr(self, name, default)

    def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
        """Apply built-in task defaults and validate ``dimensions``.

        - ``"embed"``: ``normalize`` defaults to ``True``. When both
          ``dimensions`` and ``model_config`` are given, the model must be
          matryoshka-capable; ``dimensions`` must then be one of
          ``model_config.matryoshka_dimensions`` if that is set, or at
          least 1 otherwise.
        - ``"classify"`` / ``"score"``: ``activation`` defaults to ``True``.
        - ``"encode"``: ``softmax`` defaults to ``True``.

        Raises:
            ValueError: on an unsupported ``dimensions`` value or an unknown
                task.
        """
        if self.task == "embed":
            if self.normalize is None:
                self.normalize = True

            if self.dimensions is not None and model_config is not None:
                if not model_config.is_matryoshka:
                    raise ValueError(
                        f'Model "{model_config.served_model_name}" does not '
                        f'support matryoshka representation, '
                        f'changing output dimensions will lead to poor results.'
                    )

                mds = model_config.matryoshka_dimensions
                if mds is not None:
                    if self.dimensions not in mds:
                        raise ValueError(
                            f'Model "{model_config.served_model_name}" '
                            f'only supports {str(mds)} matryoshka dimensions, '
                            f'use other output dimensions will '
                            f'lead to poor results.')
                elif self.dimensions < 1:
                    raise ValueError("Dimensions must be greater than 0")

        elif self.task in ("classify", "score"):
            if self.activation is None:
                self.activation = True

        elif self.task == "encode":
            if self.softmax is None:
                self.softmax = True
        else:
            raise ValueError(f"Unknown pooling task: {self.task}")

    def _verify_valid_parameters(self):
        """Raise if any parameter the task does not accept is set.

        Raises:
            ValueError: listing the accepted parameters and the offending
                ones that are non-``None``.
        """
        assert self.task is not None, "task must be set"
        valid_parameters = self.valid_parameters[self.task]
        invalid_parameters = [
            name for name in self.all_parameters
            if name not in valid_parameters
            and getattr(self, name, None) is not None
        ]

        if invalid_parameters:
            raise ValueError(
                f"Task {self.task} only supports {valid_parameters} "
                f"parameters, does not support "
                f"{invalid_parameters} parameters")

    def __repr__(self) -> str:
        return (f"PoolingParams("
                f"task={self.task}, "
                f"normalize={self.normalize}, "
                f"dimensions={self.dimensions}, "
                f"activation={self.activation}, "
                f"softmax={self.softmax}, "
                f"step_tag_id={self.step_tag_id}, "
                f"returned_token_ids={self.returned_token_ids}, "
                f"requires_token_ids={self.requires_token_ids}, "
                f"extra_kwargs={self.extra_kwargs})")

    def __post_init__(self) -> None:
        # Pooling produces a single result per request, so streaming output
        # kinds are not meaningful here.
        assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
            "For pooling output_kind has to be FINAL_ONLY")