# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    Optional,
    Protocol,
    Union,
    overload,
    runtime_checkable,
)
14
15
16
17
18
19
20
21
22

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
    # These imports are only needed for static type checking.
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler
else:
    # At runtime the names are only used in annotations, so `Any` is a
    # sufficient placeholder (presumably avoiding import overhead/cycles).
    VllmConfig = Any
    Pooler = Any

logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces, to avoid breaking out-of-tree (OOT) registration
# for existing models that don't inherit from the base interface classes.


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    This is a structural (`runtime_checkable`) protocol: a model conforms by
    providing the members below; inheriting from this class is not required.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None: ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply token embeddings to `input_ids`."""
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> T_co: ...


def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Return whether `model.__init__` accepts a `vllm_config` keyword.

    `model` may be a class (where `__init__` is a plain function) or an
    instance (where it is a bound method). The keyword check itself is
    delegated to `vllm.utils.supports_kw`.
    """
    model_init = model.__init__
    return supports_kw(model_init, "vllm_config")


def _check_vllm_model_get_input_embeddings(model: Union[type[object], object]) -> bool:
    """Return whether `model` has a callable `get_input_embeddings` attribute.

    Only presence and callability are checked, not the signature. A warning
    is logged when the method is missing or not callable.
    """
    model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if not callable(model_get_input_embeddings):
        logger.warning(
            "The model (%s) is missing the `get_input_embeddings` method.",
            model,
        )
        return False

    return True


def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    """Return whether `model.forward` accepts the vLLM-required keywords.

    The required keywords are `input_ids` and `positions`. Returns `False`
    without logging if `forward` is absent or not callable. When keywords
    are missing, a warning is logged only if `model` is an `nn.Module`
    subclass (a class, not an instance).
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws if not supports_kw(model_forward, kw))

    if missing_kws and isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return not missing_kws


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Return whether `model` (a class or an instance) satisfies `VllmModel`.

    The checks run in this order and short-circuit on the first failure:

    1. `__init__` accepts a `vllm_config` keyword;
    2. a callable `get_input_embeddings` exists;
    3. `forward` accepts the `input_ids` and `positions` keywords.

    Checks 2 and 3 may log a warning on failure. Explicit inheritance from
    `VllmModel` is not required.
    """
    return (
        _check_vllm_model_init(model)
        and _check_vllm_model_get_input_embeddings(model)
        and _check_vllm_model_forward(model)
    )


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    Extends `VllmModel` with a `compute_logits` method.
    """

    def compute_logits(
        self,
        hidden_states: T,
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...


@overload
def is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]: ...


@overload
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...


def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[
    TypeIs[type[VllmModelForTextGeneration]], TypeIs[VllmModelForTextGeneration]
]:
    """Return whether `model` (a class or an instance) is a generative model.

    `model` must pass `is_vllm_model` and additionally provide the members of
    `VllmModelForTextGeneration` (notably `compute_logits`).
    """
    if not is_vllm_model(model):
        return False

    # `isinstance` against a `runtime_checkable` protocol only checks that the
    # protocol's members are present, so one check covers both classes and
    # instances (the former class/instance branches were identical).
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM."""

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pooling.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    Indicates the
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Return whether `model` (a class or an instance) is a pooling model.

    `model` must pass `is_vllm_model` and carry a truthy `is_pooling_model`
    attribute. This attribute is inherited when `VllmModelForPooling` is in
    the model's MRO.
    """
    if not is_vllm_model(model):
        return False

    # Normalize to a real bool: a TypeIs guard should not leak whatever
    # arbitrary value a model happens to store in `is_pooling_model`.
    return bool(getattr(model, "is_pooling_model", False))


# Type variable for model classes (bounded to `nn.Module` subclasses) used by
# the `default_pooling_type` decorator, which returns the same class type.
_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Build a class decorator that sets `default_pooling_type` on a model.

    Args:
        pooling_type: Value stored as the decorated class's
            `default_pooling_type` attribute (see
            `VllmModelForPooling.default_pooling_type`).

    Returns:
        A decorator that sets the attribute on the class in place and hands
        back the very same class object.
    """

    def set_default(model: _T) -> _T:
        # Returning the class unchanged keeps `@default_pooling_type(...)`
        # usable as ordinary decorator syntax.
        model.default_pooling_type = pooling_type  # type: ignore
        return model

    return set_default


def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return the `default_pooling_type` declared by a model class or instance.

    Models that declare no such attribute get `"LAST"`.
    """
    declared = getattr(model, "default_pooling_type", "LAST")
    return declared