pooler.py 7.61 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from typing import Any, Literal, get_args
5
6

from vllm.config.utils import config
7
from vllm.logger import init_logger
8
from vllm.tasks import PoolingTask
9
from vllm.utils.hashing import safe_hash
10
11

logger = init_logger(__name__)
12

13
14
15
16
17
# Names of the supported sequence pooling methods.
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
# Names of the supported tokenwise pooling methods.
TokenPoolingType = Literal["ALL", "STEP"]

# Runtime tuples derived from the literal types above, so membership checks
# always agree with the type annotations.
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)
18

19
20
21
22
23

@config
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    task: PoolingTask | None = None
    """
    The task used for pooling.
    """

    pooling_type: SequencePoolingType | TokenPoolingType | None = None
    """
    The pooling method used for pooling.

    If set, `seq_pooling_type` or `tok_pooling_type` are automatically populated
    with this field. Alternatively, users can set `seq_pooling_type` and
    `tok_pooling_type` explicitly.

    This field is mainly for user convenience. Internal code should always use
    `seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
    """

    seq_pooling_type: SequencePoolingType | None = None
    """
    The pooling method used for sequence pooling.
    """

    tok_pooling_type: TokenPoolingType | None = None
    """
    The pooling method used for tokenwise pooling.
    """

    use_activation: bool | None = None
    """
    Whether to apply activation function to the pooler outputs.
    `None` uses the pooler's default, which is `True` in most cases.
    """

    ## for embedding models
    dimensions: int | None = None
    """
    Reduce the dimensions of embeddings if the model
    supports matryoshka representation. Defaults to None.
    """
    enable_chunked_processing: bool = False
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: int | None = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_embed_len to be accepted for embedding models.
    When an input exceeds max_embed_len, it will be handled according to
    the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    ## for classification models — affine score calibration
    logit_mean: float | None = None
    """
    If provided, subtract this value from classification logits before
    activation. Used for affine score calibration (Platt scaling):
    activation((logit - logit_mean) / logit_sigma). Defaults to None.
    """

    logit_sigma: float | None = None
    """
    If provided, divide the classification logits by this value after
    mean subtraction. Used for affine score calibration (Platt scaling):
    activation((logit - logit_mean) / logit_sigma). Defaults to None.
    """

    # Deprecated aliases — will be removed in v0.21
    logit_bias: float | None = None
    """
    Deprecated: Use logit_mean instead. Will be removed in v0.21.
    """

    logit_scale: float | None = None
    """
    Deprecated: Use logit_sigma instead (note: logit_sigma = 1/logit_scale).
    Will be removed in v0.21.
    """

    ## for reward models
    step_tag_id: int | None = None
    """
    If set, only the score corresponding to the `step_tag_id` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: list[int] | None = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of `good_token` and `bad_token` in the
    `math-shepherd-mistral-7b-prm` model.
    """

    def __post_init__(self) -> None:
        """Migrate deprecated fields and resolve `pooling_type`.

        Steps, in order:

        1. `logit_bias` (deprecated) is copied into `logit_mean` and cleared.
        2. `logit_scale` (deprecated) is converted to
           `logit_sigma = 1 / logit_scale` and cleared.
        3. `logit_sigma` is rejected if it equals zero.
        4. A truthy `pooling_type` is copied into `seq_pooling_type` or
           `tok_pooling_type`, depending on which set of names it belongs to.

        Raises:
            ValueError: If a deprecated field and its replacement are both
                set, if `logit_scale` or `logit_sigma` is zero, or if
                `pooling_type` is combined with `seq_pooling_type` or
                `tok_pooling_type`.
            NotImplementedError: If `pooling_type` is neither a sequence nor
                a token pooling type.
        """
        # Handle deprecated logit_bias → logit_mean
        if self.logit_bias is not None:
            if self.logit_mean is not None:
                raise ValueError(
                    "Cannot set both `logit_bias` and `logit_mean`. "
                    "`logit_bias` is deprecated, use `logit_mean` instead."
                )
            logger.warning(
                "`logit_bias` is deprecated and will be removed in v0.21. "
                "Use `logit_mean` instead."
            )
            self.logit_mean = self.logit_bias
            # Clear the alias so only the canonical field carries the value.
            self.logit_bias = None

        # Handle deprecated logit_scale → logit_sigma
        if self.logit_scale is not None:
            if self.logit_sigma is not None:
                raise ValueError(
                    "Cannot set both `logit_scale` and `logit_sigma`. "
                    "`logit_scale` is deprecated, use `logit_sigma` instead."
                )
            logger.warning(
                "`logit_scale` is deprecated and will be removed in v0.21. "
                "Use `logit_sigma` instead (logit_sigma = 1/logit_scale)."
            )
            # The old field was a multiplier, the new one is a divisor, so the
            # value is inverted; zero cannot be inverted.
            if self.logit_scale == 0:
                raise ValueError("logit_scale cannot be 0 (division by zero)")
            self.logit_sigma = 1.0 / self.logit_scale
            self.logit_scale = None

        if self.logit_sigma is not None and self.logit_sigma == 0:
            raise ValueError("logit_sigma cannot be 0 (division by zero)")

        if pooling_type := self.pooling_type:
            # The convenience field is mutually exclusive with the explicit ones.
            if self.seq_pooling_type is not None:
                raise ValueError(
                    "Cannot set both `pooling_type` and `seq_pooling_type`"
                )
            if self.tok_pooling_type is not None:
                raise ValueError(
                    "Cannot set both `pooling_type` and `tok_pooling_type`"
                )

            if pooling_type in SEQ_POOLING_TYPES:
                logger.debug(
                    "Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
                    pooling_type,
                    pooling_type,
                )
                self.seq_pooling_type = pooling_type  # type: ignore[assignment]
            elif pooling_type in TOK_POOLING_TYPES:
                logger.debug(
                    "Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
                    pooling_type,
                    pooling_type,
                )
                self.tok_pooling_type = pooling_type  # type: ignore[assignment]
            else:
                raise NotImplementedError(pooling_type)

    def get_seq_pooling_type(self) -> SequencePoolingType:
        """Return `seq_pooling_type`, asserting that it has been resolved."""
        assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
        return self.seq_pooling_type

    def get_tok_pooling_type(self) -> TokenPoolingType:
        """Return `tok_pooling_type`, asserting that it has been resolved."""
        assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
        return self.tok_pooling_type

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str