interfaces_base.py 4.98 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol,
4
                    Union, overload, runtime_checkable)
5
6
7
8
9
10
11
12
13

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
14
    from vllm.config import VllmConfig
15
    from vllm.model_executor.layers.pooler import Pooler
16
17
18
else:
    VllmConfig = Any
    Pooler = Any
19
20
21
22
23

# Module-level logger; used below to warn when a model's `forward` method
# lacks the vLLM-specific keyword arguments.
logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
25
26
27
28
29
30
31
32
33
# Invariant hidden-state type variable; falls back to `torch.Tensor` when
# left unparameterized.
T = TypeVar("T", default=torch.Tensor)
# Covariant variant of the hidden-state type variable, with the same
# `torch.Tensor` default.
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """Structural interface that every model in vLLM must satisfy."""

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the model from `vllm_config`, with an optional `prefix`."""
        ...

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> T_co:
        """Run the model on `input_ids` and `positions`.

        Returns:
            The hidden states, of type `T_co`.
        """
        ...


52
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Return whether `model.__init__` supports the `vllm_config` keyword.

    Args:
        model: A model class or a model instance.

    Returns:
        The result of `supports_kw` for `vllm_config` on `model.__init__`.
    """
    return supports_kw(model.__init__, "vllm_config")
55
56


57
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    """Return whether `model.forward` supports the vLLM-specific keywords.

    Args:
        model: A model class or a model instance.

    Returns:
        `False` if `forward` is absent or not callable, or if `supports_kw`
        rejects `input_ids` or `positions`; `True` otherwise.

    A warning listing the missing keywords is logged only when `model` is
    itself a class that subclasses `nn.Module` (instances never warn).
    """
    forward_fn = getattr(model, "forward", None)
    if not callable(forward_fn):
        return False

    absent = tuple(name for name in ("input_ids", "positions")
                   if not supports_kw(forward_fn, name))
    if not absent:
        return True

    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            absent,
        )

    return False


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Return whether `model` (a class or an instance) is a vLLM model.

    The check is signature-based rather than inheritance-based: it passes
    when both `_check_vllm_model_init` and `_check_vllm_model_forward`
    accept `model`.
    """
    # The `forward` check (which may log a warning) only runs once the
    # `__init__` check has passed.
    init_ok = _check_vllm_model_init(model)
    return init_ok and _check_vllm_model_forward(model)


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """Structural interface that every generative model in vLLM must satisfy."""

    def compute_logits(self, hidden_states: T) -> Optional[T]:
        """Compute logits from `hidden_states`.

        Returns:
            `None` if TP rank > 0.
        """
        ...


@overload
def is_text_generation_model(
        model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Return whether `model` (a class or an instance) is a generative model.

    Args:
        model: A model class or a model instance.

    Returns:
        `False` if `model` fails `is_vllm_model`; otherwise the result of a
        runtime protocol check against `VllmModelForTextGeneration`.
    """
    if not is_vllm_model(model):
        return False

    # The runtime-checkable protocol check is the same expression for
    # classes and instances, so no branching on `isinstance(model, type)`
    # is needed.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """Structural interface that every pooling model in vLLM must satisfy."""

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    Marker flag stating that this model supports pooling.

    Note:
        Model classes that have this class in their MRO inherit the flag,
        so there is no need to declare it again.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    The
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    to use by default.

    The
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator offers a convenient way to set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""
157
158
159


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
    ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
    ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Return whether `model` (a class or an instance) is a pooling model.

    Objects that fail `is_vllm_model` yield `False`. For the rest, the
    value of the `is_pooling_model` attribute is returned, defaulting to
    `False` when the attribute is missing.
    """
    flag_name = "is_pooling_model"
    return getattr(model, flag_name, False) if is_vllm_model(model) else False
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192


_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Decorator to set `VllmModelForPooling.default_pooling_type`.

    Args:
        pooling_type: The value stored on the decorated class as its
            `default_pooling_type` attribute.

    Returns:
        A class decorator that sets the attribute and hands back the very
        same class object.
    """

    def decorate(model: _T) -> _T:
        # The class is mutated in place; no wrapper or subclass is created.
        model.default_pooling_type = pooling_type  # type: ignore
        return model

    return decorate


def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return the `default_pooling_type` attribute of `model`.

    Args:
        model: A model class or a model instance.

    Returns:
        The attribute value, or `"LAST"` when `model` does not define it.
    """
    fallback = "LAST"
    return getattr(model, "default_pooling_type", fallback)