interfaces_base.py 6.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
8
9
10
11
from typing import (
    TYPE_CHECKING,
    Any,
    ClassVar,
    Literal,
    Protocol,
    overload,
    runtime_checkable,
)
12
13
14
15
16
17

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
18
from vllm.utils.func_utils import supports_kw
19
20

if TYPE_CHECKING:
    # Imported for static analysis only; nothing from these modules is loaded
    # when this file is executed.
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler
else:
    # Runtime placeholders so that annotations mentioning these names resolve.
    VllmConfig = Any
    Pooler = Any
26
27
28
29
30

# Module-level logger used by the deprecation fallback in `VllmModel` and by
# the `_check_vllm_model_*` helpers in this file.
logger = init_logger(__name__)

# The type of hidden states
# Currently, T = torch.Tensor for all models except for Medusa
# which has T = list[torch.Tensor]
32
33
34
35
36
37
38
39
40
# Invariant form: for interfaces where the hidden-state type is both accepted
# and returned (e.g. `compute_logits`).
T = TypeVar("T", default=torch.Tensor)
# Covariant form: for interfaces where the hidden-state type is only produced
# (e.g. the return type of `forward`).
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    The members below are the ones a model must provide: a constructor that
    accepts `vllm_config`, an `embed_input_ids` method, and a `forward`
    method taking `input_ids` and `positions`.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: ...

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Apply token embeddings to `input_ids`.

        This default only provides a compatibility path: when the object
        still defines the deprecated `get_input_embeddings`, a one-time
        warning is logged and the call is forwarded to it. Otherwise nothing
        is computed and `None` is returned.
        """
        if not hasattr(self, "get_input_embeddings"):
            return None

        logger.warning_once(
            "`get_input_embeddings` for vLLM models is deprecated and will be "
            "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
            "this method to `embed_input_ids`."
        )
        legacy_embeddings = self.get_input_embeddings(input_ids)
        return legacy_embeddings

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> T_co: ...
57
58


59
def _check_vllm_model_init(model: type[object] | object) -> bool:
    """Return whether `model.__init__` accepts a `vllm_config` keyword.

    The decision is delegated to `supports_kw`.
    """
    return supports_kw(model.__init__, "vllm_config")
62
63


64
65
66
67
68
69
70
71
72
73
74
def _check_vllm_model_embed_input_ids(model: type[object] | object) -> bool:
    model_embed_input_ids = getattr(model, "embed_input_ids", None)
    if not callable(model_embed_input_ids):
        model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
        if callable(model_get_input_embeddings):
            logger.warning(
                "`get_input_embeddings` for vLLM models is deprecated and will be "
                "removed in v0.13.0 or v1.0.0, whichever is earlier. Please rename "
                "this method to `embed_input_ids`."
            )
            model.embed_input_ids = model_get_input_embeddings
75
        logger.warning(
76
            "The model (%s) is missing the `embed_input_ids` method.",
77
78
79
80
81
82
83
            model,
        )
        return False

    return True


84
def _check_vllm_model_forward(model: type[object] | object) -> bool:
    """Check that `model.forward` accepts the vLLM-specific keywords.

    Returns `False` when `forward` is absent or not callable, or when it does
    not support both `input_ids` and `positions` (as judged by `supports_kw`).
    The missing keywords are only logged when `model` is a class deriving
    from `nn.Module`.
    """
    forward_fn = getattr(model, "forward", None)
    if not callable(forward_fn):
        return False

    absent = tuple(
        name
        for name in ("input_ids", "positions")
        if not supports_kw(forward_fn, name)
    )
    if not absent:
        return True

    # Only warn for `nn.Module` subclasses; other objects fail silently.
    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            absent,
        )

    return False


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]: ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]: ...


def is_vllm_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModel]] | TypeIs[VllmModel]:
    """Return whether `model` (a class or an instance) meets the base interface.

    The constructor, `embed_input_ids` and `forward` checks run in that
    order, and evaluation stops at the first one that fails.
    """
    requirements = (
        _check_vllm_model_init,
        _check_vllm_model_embed_input_ids,
        _check_vllm_model_forward,
    )
    return all(requirement(model) for requirement in requirements)
119
120
121


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM."""

    def compute_logits(self, hidden_states: T) -> T | None:
        """Compute logits from `hidden_states`.

        Return `None` if TP rank > 0.
        """


@overload
def is_text_generation_model(
    model: type[object],
) -> TypeIs[type[VllmModelForTextGeneration]]: ...


@overload
def is_text_generation_model(model: object) -> TypeIs[VllmModelForTextGeneration]: ...


def is_text_generation_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForTextGeneration]] | TypeIs[VllmModelForTextGeneration]:
    """Return whether `model` (a class or an instance) is a generative model.

    The model must first pass `is_vllm_model`; it is then matched against the
    runtime-checkable `VllmModelForTextGeneration` protocol. The same
    `isinstance` call is used for classes and for instances.
    """
    return is_vllm_model(model) and isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM."""

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    Marks the model as one that supports pooling.

    Note:
        A model class that has this class in its MRO inherits the flag and
        does not need to redefine it.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    The
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    that is used by default.

    The
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator is a convenient way to set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""
181
182
183


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]: ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]: ...


def is_pooling_model(
    model: type[object] | object,
) -> TypeIs[type[VllmModelForPooling]] | TypeIs[VllmModelForPooling]:
    """Return whether `model` (a class or an instance) is a pooling model.

    A model qualifies when it passes `is_vllm_model` and carries the
    `is_pooling_model` flag; a missing flag counts as `False`.
    """
    if is_vllm_model(model):
        return getattr(model, "is_pooling_model", False)

    return False
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Decorator to set `VllmModelForPooling.default_pooling_type`.

    Args:
        pooling_type: The value stored in the decorated class's
            `default_pooling_type` attribute.

    Returns:
        A class decorator that sets the attribute and returns the same class.
    """

    def decorator(model_cls: _T) -> _T:
        model_cls.default_pooling_type = pooling_type  # type: ignore
        return model_cls

    return decorator


213
def get_default_pooling_type(model: type[object] | object) -> str:
214
    return getattr(model, "default_pooling_type", "LAST")