pooler.py 4.27 KB
Newer Older
1
2
3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

4
from typing import Any, Literal
5
6
7
8

from pydantic.dataclasses import dataclass

from vllm.config.utils import config
9
from vllm.logger import init_logger
10
from vllm.utils.hashing import safe_hash
11
12

# Logger for this module (named after the module via `__name__`).
logger = init_logger(__name__)

# Pooling strategies that may be selected through `PoolerConfig.pooling_type`.
PoolingTypeStr = Literal[
    "LAST",
    "ALL",
    "CLS",
    "STEP",
    "MEAN",
]


@config
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: PoolingTypeStr | None = None
    """
    The pooling method of the pooling model.
    """

    ## for embeddings models
    normalize: bool | None = None
    """
    Whether to normalize the embeddings outputs. Defaults to True.
    """
    dimensions: int | None = None
    """
    Reduce the dimensions of embeddings if the model supports matryoshka
    representation. Defaults to None.
    """
    enable_chunked_processing: bool | None = None
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: int | None = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_model_len (up to max_embed_len) to be accepted for
    embedding models. When an input exceeds max_embed_len, it will be handled
    according to the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    ## for classification models
    # NOTE(review): `softmax` and `activation` are annotated as float, but they
    # are used as boolean flags (their value is copied into `use_activation`
    # by `__post_init__`). Confirm whether they should be `bool | None`.
    softmax: float | None = None
    """
    softmax is deprecated (scheduled for removal in v0.15), please use
    use_activation instead. If set, it takes precedence over use_activation.
    """
    activation: float | None = None
    """
    activation is deprecated (scheduled for removal in v0.15), please use
    use_activation instead. If set (and softmax is not), it takes precedence
    over use_activation.
    """
    use_activation: bool | None = None
    """
    Whether to apply activation function to the classification outputs.
    Defaults to True.
    """
    logit_bias: float | None = None
    """
    If provided, apply classification logit biases. Defaults to None.
    """

    ## for reward models
    step_tag_id: int | None = None
    """
    If set, only the score corresponding to the `step_tag_id` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: list[int] | None = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of `good_token` and `bad_token` in the
    `math-shepherd-mistral-7b-prm` model.
    """

    def __post_init__(self):
        # Fold the deprecated `softmax` / `activation` aliases into
        # `use_activation`; get_use_activation logs a deprecation warning
        # (it does not raise) when either alias is set.
        self.use_activation = get_use_activation(self)

    def get_pooling_type(self) -> PoolingTypeStr:
        """Return the resolved pooling type.

        `pooling_type` may be None when the config is constructed; it is
        expected to have been filled in by ModelConfig before this is called.
        """
        assert self.pooling_type is not None, "Should be resolved by ModelConfig"
        return self.pooling_type

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
        return hash_str


def get_use_activation(o: object):
    """Resolve the effective ``use_activation`` setting of a config-like object.

    The deprecated ``softmax`` and ``activation`` attributes take precedence,
    in that order, over ``use_activation``. Using either of them logs a
    deprecation warning via ``logger.warning_once``.

    Args:
        o: Any object. Its ``softmax``, ``activation`` and ``use_activation``
            attributes are read with ``getattr``; missing ones count as None.

    Returns:
        The value of the first of ``softmax``, ``activation`` or
        ``use_activation`` that is not None, or None if none of them is set.
    """
    # The assignment expressions must be parenthesized. `x := a is not None`
    # binds `x` to the boolean comparison result rather than to the attribute
    # value, which would turn e.g. softmax=False into True.
    if (softmax := getattr(o, "softmax", None)) is not None:
        logger.warning_once(
            "softmax will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return softmax

    if (activation := getattr(o, "activation", None)) is not None:
        logger.warning_once(
            "activation will be deprecated and will be removed in v0.15. "
            "Please use use_activation instead."
        )
        return activation

    return getattr(o, "use_activation", None)