# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    Protocol,
    overload,
    runtime_checkable,
)
12
13
14
15
16
17

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
18
from vllm.utils.func_utils import supports_kw
19
20

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import SequencePoolingType, TokenPoolingType
    from vllm.model_executor.layers.pooler import Pooler
else:
    # At runtime these names appear in annotations that are evaluated when
    # the protocols/functions below are defined, so they must exist. Alias
    # them to `Any` instead of importing the real modules (presumably to
    # avoid import cycles / import cost — TODO confirm).
    VllmConfig = Any
    Pooler = Any
    SequencePoolingType = Any
    TokenPoolingType = Any
    AttnTypeStr = Any

logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes.


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    `T_co` is the type of the hidden states returned by `forward`
    (`torch.Tensor` by default).

    Because this is a runtime-checkable `Protocol`, an `isinstance` check only
    verifies that the members exist. It does not check their signatures.
    `is_vllm_model` additionally checks that `__init__` accepts `vllm_config`
    and that `forward` accepts `input_ids` and `positions`.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Apply token embeddings to `input_ids`."""
        ...

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...


def _check_vllm_model_init(model: type[object] | object) -> bool:
    """Return whether `model.__init__` accepts a `vllm_config` keyword.

    Works for both a model class and a model instance, since `__init__` is
    looked up as an attribute on whichever is passed.
    """
    model_init = model.__init__
    return supports_kw(model_init, "vllm_config")


63
64
65
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
    model_embed_input_ids = getattr(model, "embed_input_ids", None)
    if not callable(model_embed_input_ids):
66
        logger.warning(
67
            "The model (%s) is missing the `embed_input_ids` method.",
68
69
70
71
72
73
74
            model,
        )
        return False

    return True


75
def _check_vllm_model_forward(model: type[object] | object) -> bool:
76
77
78
79
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

80
    vllm_kws = ("input_ids", "positions")
81
    missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw))
82

83
    if missing_kws and (isinstance(model, type) and issubclass(model, nn.Module)):
84
85
        logger.warning(
            "The model (%s) is missing "
86
            "vLLM-specific keywords from its `forward` method: %s",
87
88
89
90
91
92
93
94
            model,
            missing_kws,
        )

    return len(missing_kws) == 0


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...


def is_vllm_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
    """Return whether `model` (a class or an instance) is a vLLM model.

    The checks run in order and stop at the first failure, so warnings from
    later checks are not emitted once one fails:

    1. `__init__` accepts a `vllm_config` keyword.
    2. `embed_input_ids` exists and is callable.
    3. `forward` accepts the `input_ids` and `positions` keywords.
    """
    return (
        _check_vllm_model_init(model)
        and _check_vllm_model_embed_input_ids(model)
        and _check_vllm_model_forward(model)
    )


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    Extends `VllmModel` with `compute_logits`, which maps hidden states of
    type `T` to logits.
    """

    def compute_logits(
        self,
        hidden_states: T,
    ) -> T | None:
        """Return `None` if TP rank > 0."""
        ...


@overload
def is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]: ...


@overload
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...


def is_text_generation_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
    """Return whether `model` (a class or an instance) is a generative model.

    `model` must first pass `is_vllm_model`. It must then structurally match
    `VllmModelForTextGeneration`, i.e. also provide `compute_logits`.
    """
    if not is_vllm_model(model):
        return False

    # A runtime-checkable Protocol's `isinstance` check only tests that the
    # protocol members are present as attributes. Methods are attributes of
    # the class as well as of its instances, so one check covers both cases.
    # The previous separate `isinstance(model, type)` branch did the same
    # thing and has been removed.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM.

    The class-level defaults below are read via `getattr` by the helper
    functions in this module. A model that does not inherit from this
    protocol can therefore still set them directly or with the decorators.
    """

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pooling.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    default_seq_pooling_type: ClassVar[SequencePoolingType] = "LAST"
    """
    Indicates the [vllm.config.pooler.PoolerConfig.seq_pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    default_tok_pooling_type: ClassVar[TokenPoolingType] = "ALL"
    """
    Indicates the [vllm.config.pooler.PoolerConfig.tok_pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    attn_type: ClassVar[AttnTypeStr] = "decoder"
    """
    Indicates the
    [vllm.config.model.ModelConfig.attn_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.attn_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...


def is_pooling_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
    """Return whether `model` (a class or an instance) is a pooling model.

    `model` must pass `is_vllm_model` and have a truthy `is_pooling_model`
    attribute. The attribute is read with `getattr` (defaulting to `False`),
    so models do not need to inherit from `VllmModelForPooling`.
    """
    if not is_vllm_model(model):
        return False

    # Coerce to a real bool so the `TypeIs` predicate always returns a bool,
    # even if a model sets the flag to some other truthy/falsy value.
    return bool(getattr(model, "is_pooling_model", False))


# Type of the model class taken and returned by the class decorators in this
# module (`def func(model: _T) -> _T`), so decorating keeps the concrete class
# type for type checkers.
_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(
    *,
    seq_pooling_type: SequencePoolingType = "LAST",
    tok_pooling_type: TokenPoolingType = "ALL",
):
    """Decorator to set `VllmModelForPooling.default_*_pooling_type`.

    Args:
        seq_pooling_type: Value assigned to the decorated class's
            `default_seq_pooling_type` attribute.
        tok_pooling_type: Value assigned to the decorated class's
            `default_tok_pooling_type` attribute.

    Returns:
        A class decorator that sets both attributes on the model class (in
        place) and returns that same class.
    """

    def func(model: _T) -> _T:
        model.default_seq_pooling_type = seq_pooling_type  # type: ignore
        model.default_tok_pooling_type = tok_pooling_type  # type: ignore
        return model

    return func


def get_default_seq_pooling_type(
    model: type[object] | object,
) -> SequencePoolingType:
    """Return `model.default_seq_pooling_type`, or `"LAST"` if it is unset.

    Accepts either a model class or a model instance.
    """
    return getattr(model, "default_seq_pooling_type", "LAST")


def get_default_tok_pooling_type(
    model: type[object] | object,
) -> TokenPoolingType:
    """Return `model.default_tok_pooling_type`, or `"ALL"` if it is unset.

    Accepts either a model class or a model instance.
    """
    return getattr(model, "default_tok_pooling_type", "ALL")


def attn_type(attn_type: AttnTypeStr):
    """Build a class decorator that sets `VllmModelForPooling.attn_type`.

    Args:
        attn_type: The value stored as the decorated model class's
            `attn_type` class attribute.

    Returns:
        A decorator that assigns the attribute on the class (in place) and
        returns that same class.
    """
    chosen_attn_type = attn_type

    def _decorate(model: _T) -> _T:
        setattr(model, "attn_type", chosen_attn_type)
        return model

    return _decorate


def get_attn_type(model: type[object] | object) -> AttnTypeStr:
    """Return `model.attn_type`, falling back to `"decoder"` when it is unset.

    Accepts either a model class or a model instance.
    """
    try:
        return model.attn_type  # type: ignore
    except AttributeError:
        return "decoder"