pooler.py 6.3 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from typing import Any, Literal, get_args
5
6
7
8

from pydantic.dataclasses import dataclass

from vllm.config.utils import config
9
from vllm.logger import init_logger
10
from vllm.utils.hashing import safe_hash
11
12

# Module-level logger; used for the `pooling_type` resolution debug messages
# and the deprecation warnings emitted by `get_use_activation`.
logger = init_logger(__name__)
# Pooling methods accepted by `seq_pooling_type` and `tok_pooling_type`
# respectively (the convenience field `pooling_type` accepts either kind).
SequencePoolingType = Literal["CLS", "LAST", "MEAN"]
TokenPoolingType = Literal["ALL", "STEP"]

# Runtime tuples of the allowed values, derived from the Literal aliases above
# so that the type annotations and the runtime membership checks cannot drift.
SEQ_POOLING_TYPES: tuple[SequencePoolingType, ...] = get_args(SequencePoolingType)
TOK_POOLING_TYPES: tuple[TokenPoolingType, ...] = get_args(TokenPoolingType)

@config
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: SequencePoolingType | TokenPoolingType | None = None
    """
    The pooling method used for pooling.

    If set, `__post_init__` copies it into `seq_pooling_type` (for one of
    "CLS", "LAST", "MEAN") or `tok_pooling_type` (for one of "ALL", "STEP").
    Setting it together with either of those fields is an error. Alternatively,
    users can set `seq_pooling_type` and `tok_pooling_type` explicitly.

    This field is mainly for user convenience. Internal code should always use
    `seq_pooling_type` or `tok_pooling_type` instead of `pooling_type`.
    """

    seq_pooling_type: SequencePoolingType | None = None
    """
    The pooling method used for sequence pooling.
    """

    tok_pooling_type: TokenPoolingType | None = None
    """
    The pooling method used for tokenwise pooling.
    """

    ## for embeddings models
    normalize: bool | None = None
    """
    Whether to normalize the embedding outputs. Defaults to True.
    """
    dimensions: int | None = None
    """
    Reduce the dimensions of embeddings if the model supports
    Matryoshka representation. Defaults to None.
    """
    enable_chunked_processing: bool | None = None
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: int | None = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_embed_len to be accepted for embedding models.
    When an input exceeds max_embed_len, it will be handled according to
    the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    ## for classification models
    # `softmax` and `activation` are deprecated boolean aliases of
    # `use_activation`; their value is copied into `use_activation` by
    # `__post_init__`, so they are typed as `bool | None` to match it.
    softmax: bool | None = None
    """
    Deprecated alias of `use_activation`, to be removed in v0.15.
    Please use `use_activation` instead.
    """
    activation: bool | None = None
    """
    Deprecated alias of `use_activation`, to be removed in v0.15.
    Please use `use_activation` instead.
    """
    use_activation: bool | None = None
    """
    Whether to apply activation function to the classification outputs.
    Defaults to True.
    """
    logit_bias: float | None = None
    """
    If provided, apply classification logit biases. Defaults to None.
    """

    ## for reward models
    step_tag_id: int | None = None
    """
    If set, only the score corresponding to the `step_tag_id` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: list[int] | None = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of `good_token` and `bad_token` in the
    `math-shepherd-mistral-7b-prm` model.
    """

    def __post_init__(self):
        """Normalize deprecated and convenience fields.

        Raises:
            ValueError: if `pooling_type` is combined with an explicit
                `seq_pooling_type` or `tok_pooling_type`.
            NotImplementedError: if `pooling_type` is not a known
                sequence or token pooling method.
        """
        # Fold the deprecated `softmax` / `activation` aliases into
        # `use_activation` (emits a deprecation warning if either is set).
        self.use_activation = get_use_activation(self)

        # Route the convenience `pooling_type` field into the specific field
        # matching its kind; it may not be combined with either of them.
        if pooling_type := self.pooling_type:
            if self.seq_pooling_type is not None:
                raise ValueError(
                    "Cannot set both `pooling_type` and `seq_pooling_type`"
                )
            if self.tok_pooling_type is not None:
                raise ValueError(
                    "Cannot set both `pooling_type` and `tok_pooling_type`"
                )

            if pooling_type in SEQ_POOLING_TYPES:
                logger.debug(
                    "Resolved `pooling_type=%r` to `seq_pooling_type=%r`.",
                    pooling_type,
                    pooling_type,
                )
                self.seq_pooling_type = pooling_type
            elif pooling_type in TOK_POOLING_TYPES:
                logger.debug(
                    "Resolved `pooling_type=%r` to `tok_pooling_type=%r`.",
                    pooling_type,
                    pooling_type,
                )
                self.tok_pooling_type = pooling_type
            else:
                raise NotImplementedError(pooling_type)

    def get_seq_pooling_type(self) -> SequencePoolingType:
        """Return `seq_pooling_type`, which must already have been resolved."""
        assert self.seq_pooling_type is not None, "Should be resolved by ModelConfig"
        return self.seq_pooling_type

    def get_tok_pooling_type(self) -> TokenPoolingType:
        """Return `tok_pooling_type`, which must already have been resolved."""
        assert self.tok_pooling_type is not None, "Should be resolved by ModelConfig"
        return self.tok_pooling_type

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str


def get_use_activation(o: object):
    """Resolve the effective ``use_activation`` value of *o*.

    The deprecated ``softmax`` and ``activation`` attributes take precedence,
    in that order, over ``use_activation``. Using either one logs a one-time
    deprecation warning. An attribute that is missing on *o* is treated as
    ``None``.

    Returns:
        ``o.softmax`` if it is not ``None``; otherwise ``o.activation`` if it
        is not ``None``; otherwise ``o.use_activation`` (possibly ``None``).
    """
    # The assignment expressions must be parenthesized. Without parentheses,
    # `:=` binds the *result* of the `is not None` test (always True in this
    # branch), so e.g. `softmax=False` was wrongly returned as `True`.
    if (softmax := getattr(o, "softmax", None)) is not None:
        logger.warning_once(
            "softmax will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return softmax

    if (activation := getattr(o, "activation", None)) is not None:
        logger.warning_once(
            "activation will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return activation

    return getattr(o, "use_activation", None)