interfaces_base.py 5.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
    # These names are only needed by static type checkers.
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler
else:
    # Runtime placeholders so annotations that mention these names still
    # evaluate without importing the real classes.
    VllmConfig = Any
    Pooler = Any
19
20
21
22
23

# Module-level logger; the structural checks in this file use it to warn
# about models that are missing required methods or keywords.
logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
25
26
27
28
29
30
31
32
33
# Invariant hidden-state type; used by `VllmModelForTextGeneration`, where it
# appears both as a parameter and as a return type. Defaults to `torch.Tensor`.
T = TypeVar("T", default=torch.Tensor)
# Covariant hidden-state type; used by `VllmModel` and `VllmModelForPooling`,
# where it only appears as a return type. Defaults to `torch.Tensor`.
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    `T_co` is the type of hidden states returned by `forward`.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Apply token embeddings to `input_ids`."""
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> T_co:
        """Run the model on `input_ids` and `positions`.

        Returns:
            The hidden states, of type `T_co`.
        """
        ...


59
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Check that the model's `__init__` accepts a `vllm_config` keyword.

    Args:
        model: A model class or instance to inspect.

    Returns:
        The result of `supports_kw` for the `vllm_config` keyword on
        `model.__init__`.
    """
    model_init = model.__init__
    return supports_kw(model_init, "vllm_config")
62
63


64
65
66
67
68
69
70
71
72
73
74
75
76
def _check_vllm_model_get_input_embeddings(
        model: Union[type[object], object]) -> bool:
    model_get_input_embeddings = getattr(model, "get_input_embeddings", None)
    if not callable(model_get_input_embeddings):
        logger.warning(
            "The model (%s) is missing the `get_input_embeddings` method.",
            model,
        )
        return False

    return True


77
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
78
79
80
81
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

82
    vllm_kws = ("input_ids", "positions")
83
84
85
86
87
88
89
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        logger.warning(
            "The model (%s) is missing "
90
            "vLLM-specific keywords from its `forward` method: %s",
91
92
93
94
95
96
97
98
            model,
            missing_kws,
        )

    return len(missing_kws) == 0


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Check structurally whether `model` satisfies the `VllmModel` interface.

    Args:
        model: A model class or instance to inspect.

    Returns:
        `True` if the `__init__`, `get_input_embeddings` and `forward`
        checks all pass. The checks run in that order and stop at the first
        failure, so later checks (and their warnings) are skipped once one
        fails.
    """
    return (_check_vllm_model_init(model)
            and _check_vllm_model_get_input_embeddings(model)
            and _check_vllm_model_forward(model))
114
115
116


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    `T` is the type of hidden states, which `compute_logits` both accepts
    and returns.
    """

    def compute_logits(
        self,
        hidden_states: T,
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...


@overload
def is_text_generation_model(
        model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Check whether `model` satisfies `VllmModelForTextGeneration`.

    Args:
        model: A model class or instance to inspect.

    Returns:
        `False` if `is_vllm_model(model)` fails; otherwise the result of the
        runtime protocol check against `VllmModelForTextGeneration`.
    """
    if not is_vllm_model(model):
        return False

    # The runtime protocol check is the same expression whether `model` is a
    # class or an instance, so a single `isinstance` call covers both.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM."""

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pooling.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    Indicates the
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""
179
180
181


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
    ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
    ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Check whether `model` is a vLLM model flagged as a pooling model.

    Args:
        model: A model class or instance to inspect.

    Returns:
        `False` if `is_vllm_model(model)` fails; otherwise the value of the
        model's `is_pooling_model` attribute, or `False` when it is absent.
    """
    if not is_vllm_model(model):
        return False

    # Relies on the `is_pooling_model` flag rather than a protocol
    # `isinstance` check.
    return getattr(model, "is_pooling_model", False)
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214


# Type variable for the model class handled by `default_pooling_type`, so the
# decorator is typed as returning the same class type it receives.
_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Build a class decorator for `VllmModelForPooling.default_pooling_type`.

    Args:
        pooling_type: The value to store on the decorated model class.

    Returns:
        A decorator that assigns `pooling_type` to the class attribute
        `default_pooling_type` and hands back the same class.
    """

    def decorate(model: _T) -> _T:
        setattr(model, "default_pooling_type", pooling_type)
        return model

    return decorate


def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Look up the default pooling type declared on `model`.

    Args:
        model: A model class or instance to inspect.

    Returns:
        The `default_pooling_type` attribute of `model`, falling back to
        `"LAST"` when the attribute is absent.
    """
    fallback = "LAST"
    pooling_type = getattr(model, "default_pooling_type", fallback)
    return pooling_type