interfaces_base.py 5.43 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    Protocol,
    overload,
    runtime_checkable,
)
12
13
14
15
16
17

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
18
from vllm.utils.func import supports_kw
19
20

if TYPE_CHECKING:
    # Real classes, visible only to static type checkers.
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler
else:
    # Runtime placeholders so the annotations in this module still evaluate.
    # NOTE(review): the runtime import is presumably skipped to avoid an
    # import cycle or import-time cost — confirm.
    VllmConfig = Pooler = Any

logger = init_logger(__name__)

# Type of the hidden states a model produces. It defaults to `torch.Tensor`,
# which is what every model currently uses except Medusa, whose hidden states
# are a `list[torch.Tensor]`.
T = TypeVar("T", default=torch.Tensor)

# Covariant counterpart of `T`, used where hidden states only appear as a
# return value (e.g. `VllmModel.forward`).
T_co = TypeVar("T_co", covariant=True, default=torch.Tensor)

# NOTE: Unlike the interfaces in `interfaces.py`, we don't define `ClassVar`
# tags for the base interfaces, to avoid breaking out-of-tree (OOT) registration
# of existing models that don't inherit from the base interface classes.


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    This is a structural protocol: a model class does not need to inherit
    from it. Because it is `@runtime_checkable`, an `isinstance` check against
    it only verifies that the members exist, not their signatures. The
    keyword-level checks on `__init__` and `forward` are done by
    `is_vllm_model` in this module.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` using the model's token embeddings."""
        ...

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co:
        """Return the hidden states (of type `T_co`) for `input_ids`."""
        ...


def _check_vllm_model_init(model: type[object] | object) -> bool:
    """Return whether `model.__init__` takes a `vllm_config` keyword.

    The decision is delegated to `supports_kw`. `model` may be either a
    model class or a model instance.
    """
    return supports_kw(model.__init__, "vllm_config")


def _check_vllm_model_get_input_embeddings(model: type[object] | object) -> bool:
    """Return whether `model` exposes a callable `get_input_embeddings`.

    When the attribute is absent or not callable, a warning is logged and
    False is returned. `model` may be a class or an instance.
    """
    embed_fn = getattr(model, "get_input_embeddings", None)
    if callable(embed_fn):
        return True

    logger.warning(
        "The model (%s) is missing the `get_input_embeddings` method.",
        model,
    )
    return False


def _check_vllm_model_forward(model: type[object] | object) -> bool:
    """Return whether `model.forward` accepts the vLLM-specific keywords.

    A missing or non-callable `forward` yields False without logging.
    Otherwise, every keyword in (`input_ids`, `positions`) must be accepted,
    as judged by `supports_kw`. Missing keywords are reported with a warning,
    but only when `model` is an `nn.Module` subclass (a class, not an
    instance).
    """
    forward_fn = getattr(model, "forward", None)
    if not callable(forward_fn):
        return False

    missing_kws = tuple(
        kw for kw in ("input_ids", "positions") if not supports_kw(forward_fn, kw)
    )

    # Same short-circuit order as before: `issubclass` is only evaluated for
    # class objects, and only when some keyword is actually missing.
    if missing_kws and isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return not missing_kws


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...


def is_vllm_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
    """Check whether `model` (a class or an instance) satisfies `VllmModel`.

    The checks run in this order and stop at the first failure:

    1. `__init__` accepts a `vllm_config` keyword;
    2. a callable `get_input_embeddings` exists (a warning is logged if not);
    3. `forward` is callable and accepts `input_ids` and `positions`.
    """
    checks = (
        _check_vllm_model_init,
        _check_vllm_model_get_input_embeddings,
        _check_vllm_model_forward,
    )
    # `all` over a generator evaluates lazily, preserving the short-circuit.
    return all(check(model) for check in checks)


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    Extends the base `VllmModel` members with `compute_logits`.
    """

    def compute_logits(self, hidden_states: T) -> T | None:
        """Compute logits from `hidden_states`.

        Returns `None` if the tensor-parallel (TP) rank is greater than 0.
        """
        ...


@overload
def is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]: ...


@overload
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...


def is_text_generation_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
    """Return whether `model` (a class or an instance) is a generative model.

    `model` must first pass `is_vllm_model`. It must then also provide the
    members of `VllmModelForTextGeneration`, notably `compute_logits`.
    """
    if not is_vllm_model(model):
        return False

    # `VllmModelForTextGeneration` is a runtime-checkable protocol whose
    # members are all methods, so `isinstance` only checks that those
    # attributes are present. A class object exposes its methods as
    # attributes just like an instance does, so a single check serves both
    # overloads; no separate branch for `isinstance(model, type)` is needed.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM.

    On top of the base `VllmModel` members, a pooling model carries two
    class-level settings and a `pooler` attribute.
    """

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    Flag marking this model as one that supports pooling.

    Note:
        Model classes that already have this class in their MRO inherit
        the flag and do not need to redefine it.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    The default
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    for this model.

    The
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator is a convenient way to set this field.
    """

    pooler: Pooler
    """The pooler; it is only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...


def is_pooling_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
    """Return whether `model` (a class or an instance) is a pooling model.

    `model` must first pass `is_vllm_model`. It then counts as a pooling
    model when its `is_pooling_model` attribute is truthy. Classes with
    `VllmModelForPooling` in their MRO inherit that flag as `True`.
    """
    if not is_vllm_model(model):
        return False

    # Coerce to a real bool. The attribute can be overridden with any value,
    # while callers of a `TypeIs` predicate expect `True`/`False`.
    return bool(getattr(model, "is_pooling_model", False))


_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Build a class decorator that sets `default_pooling_type` on a model.

    The returned decorator assigns `pooling_type` to the decorated class's
    `default_pooling_type` attribute, modifying the class in place. It then
    returns that same class object.

    Args:
        pooling_type: Value to store as the class's default pooling type.

    Returns:
        A decorator that takes a model class and returns it.
    """

    def _set_default(model_cls: _T) -> _T:
        # NOTE(review): the ignore is presumably needed because the `_T` bound
        # (`type[nn.Module]`) does not declare this attribute.
        model_cls.default_pooling_type = pooling_type  # type: ignore
        return model_cls

    return _set_default


210
def get_default_pooling_type(model: type[object] | object) -> str:
211
    return getattr(model, "default_pooling_type", "LAST")