# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    Protocol,
    overload,
    runtime_checkable,
)

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils.func_utils import supports_kw

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.config.model import AttnTypeStr
    from vllm.config.pooler import PoolingTypeStr
    from vllm.model_executor.layers.pooler import Pooler
else:
    VllmConfig = Any
    Pooler = Any
    PoolingTypeStr = Any
    AttnTypeStr = Any

logger = init_logger(__name__)

# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = list[torch.Tensor]
T = TypeVar("T", default=torch.Tensor)
# Covariant counterpart of `T`; used by `VllmModel`, whose `forward` only
# produces (returns) the hidden-states type.
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Apply token embeddings to `input_ids`.

        This default implementation is a backward-compatibility shim. Models
        that still define the deprecated `get_input_embeddings` method are
        forwarded to it, and a deprecation warning is logged via
        `logger.warning_once`. Models that define neither method must
        override `embed_input_ids`.

        Raises:
            NotImplementedError: If the model defines neither
                `embed_input_ids` nor `get_input_embeddings`.
        """
        if hasattr(self, "get_input_embeddings"):
            logger.warning_once(
                "`get_input_embeddings` for vLLM models is deprecated and will be "
                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
                "this method to `embed_input_ids`."
            )
            return self.get_input_embeddings(input_ids)

        # Falling through here would implicitly return `None`, violating the
        # declared return type and deferring the failure to an unrelated
        # error downstream, so fail loudly instead.
        raise NotImplementedError(
            f"{type(self).__name__} must implement `embed_input_ids`."
        )

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...


def _check_vllm_model_init(model: type[object] | object) -> bool:
    """Return whether `model.__init__` accepts a `vllm_config` keyword.

    `model` may be either a model class or a model instance.
    """
    model_init = model.__init__
    return supports_kw(model_init, "vllm_config")


def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
    """Return whether `model` (a class or an instance) has a callable
    `embed_input_ids` attribute.

    A warning naming the model is logged when the method is missing.
    """
    model_embed_input_ids = getattr(model, "embed_input_ids", None)
    if not callable(model_embed_input_ids):
        logger.warning(
            "The model (%s) is missing the `embed_input_ids` method.",
            model,
        )
        return False

    return True


def _check_vllm_model_forward(model: type[object] | object) -> bool:
    """Return whether `model.forward` accepts the vLLM-specific keywords.

    The keywords checked are `input_ids` and `positions`. A missing or
    non-callable `forward` fails the check without logging. Missing keywords
    are logged as a warning only when `model` is an `nn.Module` subclass.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw))

    # Only warn for `nn.Module` classes. Instances and non-module classes are
    # still rejected, just without a log message.
    if missing_kws and isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return not missing_kws


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...


def is_vllm_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
    """Return whether `model` (a class or an instance) satisfies `VllmModel`.

    The check is structural and short-circuits in this order:
    1. `__init__` must accept `vllm_config`.
    2. `embed_input_ids` must be callable.
    3. `forward` must accept `input_ids` and `positions`.
    """
    return (
        _check_vllm_model_init(model)
        and _check_vllm_model_embed_input_ids(model)
        and _check_vllm_model_forward(model)
    )


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM."""

    def compute_logits(
        self,
        hidden_states: T,
    ) -> T | None:
        """Map `hidden_states` to logits.

        Return `None` if TP rank > 0.
        """
        ...


@overload
def is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]: ...


@overload
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...


def is_text_generation_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
    """Return whether `model` (a class or an instance) is a generative model.

    `model` must first pass `is_vllm_model`. It must then structurally match
    `VllmModelForTextGeneration`.
    """
    if not is_vllm_model(model):
        return False

    # One check covers both classes and instances. `isinstance` against a
    # `runtime_checkable` protocol only verifies that the protocol's members
    # are present on `model`.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM."""

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pooling.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    default_pooling_type: ClassVar[PoolingTypeStr] = "LAST"
    """
    Indicates the [vllm.config.pooler.PoolerConfig.pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    attn_type: ClassVar[AttnTypeStr] = "decoder"
    """
    Indicates the
    [vllm.config.model.ModelConfig.attn_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.attn_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...


def is_pooling_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
    """Return whether `model` (a class or an instance) is a pooling model.

    `model` must pass `is_vllm_model` and expose a truthy `is_pooling_model`
    attribute. The flag is read with `getattr`, so a model that does not
    inherit `VllmModelForPooling` can still opt in by defining the attribute.
    """
    if not is_vllm_model(model):
        return False

    # Coerce to `bool` so this `TypeIs` guard always returns a real boolean,
    # even if a model sets the flag to some other truthy value.
    return bool(getattr(model, "is_pooling_model", False))


# Type variable bound to `nn.Module` subclasses. The class decorators in this
# module use it so they return the exact class type they were given.
_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: PoolingTypeStr):
    """Decorator to set `VllmModelForPooling.default_pooling_type`.

    Returns a class decorator that assigns `pooling_type` to the decorated
    model class's `default_pooling_type` attribute. The same class object is
    returned.
    """

    def func(model: _T) -> _T:
        model.default_pooling_type = pooling_type  # type: ignore
        return model

    return func


def get_default_pooling_type(model: type[object] | object) -> PoolingTypeStr:
    """Return `model`'s `default_pooling_type`, or `"LAST"` if it has none.

    `model` may be a class or an instance. The fallback matches the default
    declared on `VllmModelForPooling.default_pooling_type`.
    """
    return getattr(model, "default_pooling_type", "LAST")


def attn_type(attn_type: AttnTypeStr):
    """Decorator factory for `VllmModelForPooling.attn_type`.

    The returned class decorator stores `attn_type` on the decorated model
    class as its `attn_type` attribute and hands back that same class.
    """
    chosen_attn_type = attn_type

    def func(model: _T) -> _T:
        setattr(model, "attn_type", chosen_attn_type)
        return model

    return func


def get_attn_type(model: type[object] | object) -> AttnTypeStr:
    """Look up the attention type declared by `model` (a class or instance).

    Falls back to `"decoder"` when `model` has no `attn_type` attribute. This
    is the same default that `VllmModelForPooling.attn_type` declares.
    """
    try:
        return model.attn_type  # type: ignore
    except AttributeError:
        return "decoder"