interfaces_base.py 5.12 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import (TYPE_CHECKING, Any, ClassVar, Literal, Optional, Protocol,
4
                    Union, overload, runtime_checkable)
5
6
7
8
9
10
11
12
13

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
    from vllm.config import VllmConfig
    from vllm.model_executor.layers.pooler import Pooler
    from vllm.model_executor.sampling_metadata import SamplingMetadata
else:
    # In this module these names are used only inside annotations, so at
    # runtime they are bound to `Any` instead of importing the real modules
    # (presumably to avoid import cycles / import cost -- unverified).
    VllmConfig = Any
    Pooler = Any
    SamplingMetadata = Any

logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
T = TypeVar("T", default=torch.Tensor)
# Covariant form of `T`, used where the hidden-states type appears only in
# output position (e.g. the return type of `VllmModel.forward`).
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """The interface required for all models in vLLM.

    This is a structural protocol: a model does not have to inherit from it.
    ``is_vllm_model`` checks conformance by inspecting signatures: ``__init__``
    must accept a ``vllm_config`` keyword and ``forward`` must accept
    ``input_ids`` and ``positions`` keywords.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> None:
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> T_co:
        """Run the model and return its hidden states (of type ``T_co``)."""
        ...


def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Return whether ``model.__init__`` accepts a ``vllm_config`` keyword.

    ``model`` may be a model class or an instance of one. The keyword check
    itself is delegated to ``vllm.utils.supports_kw``.
    """
    model_init = model.__init__
    return supports_kw(model_init, "vllm_config")


def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts vLLM's required keywords.

    The required keywords are ``input_ids`` and ``positions``; support for
    each is checked with ``vllm.utils.supports_kw``.

    Args:
        model: A model class or an instance of one.

    Returns:
        ``False`` if ``model`` has no callable ``forward`` attribute or if any
        required keyword is unsupported, ``True`` otherwise.

    A warning listing the missing keywords is logged only when ``model`` is an
    ``nn.Module`` *class*; other inputs are rejected silently.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    # Warn only for nn.Module subclasses (not instances or arbitrary objects).
    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            missing_kws,
        )

    return not missing_kws


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Return whether ``model`` (a class or an instance) conforms to
    :class:`VllmModel`.

    The check is signature-based: ``__init__`` must accept ``vllm_config``
    and ``forward`` must accept ``input_ids`` and ``positions``. Because of
    short-circuiting, ``forward`` (and its possible warning) is only examined
    when the ``__init__`` check passes.
    """
    return _check_vllm_model_init(model) and _check_vllm_model_forward(model)


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """The interface required for all generative models in vLLM.

    In addition to the :class:`VllmModel` members, a generative model must
    provide ``compute_logits``.
    """

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[T]:
        """Return `None` if TP rank > 0."""
        ...


@overload
def is_text_generation_model(
        model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Return whether ``model`` (a class or an instance) conforms to
    :class:`VllmModelForTextGeneration`.

    ``model`` must first pass :func:`is_vllm_model`. The protocol's
    ``isinstance`` check is then applied to ``model`` as-is. Classes and
    instances share this path, because a runtime-checkable protocol's
    ``isinstance`` only checks that the required members are present, and a
    class object exposes its methods as attributes too.
    """
    if not is_vllm_model(model):
        return False

    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T_co], Protocol[T_co]):
    """The interface required for all pooling models in vLLM.

    :func:`is_pooling_model` identifies such models by reading the
    ``is_pooling_model`` flag declared here.
    """

    is_pooling_model: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pooling.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    default_pooling_type: ClassVar[str] = "LAST"
    """
    Indicates the
    [vllm.model_executor.layers.pooler.PoolerConfig.pooling_type][]
    to use by default.

    You can use the
    [vllm.model_executor.models.interfaces_base.default_pooling_type][]
    decorator to conveniently set this field.
    """

    pooler: Pooler
    """The pooler is only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
    ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
    ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Return whether ``model`` (a class or an instance) is a pooling model.

    ``model`` must pass :func:`is_vllm_model` and have a truthy
    ``is_pooling_model`` attribute (absent counts as ``False``). The result
    is always a real ``bool``, as a ``TypeIs`` predicate should return,
    rather than the raw attribute value.
    """
    if not is_vllm_model(model):
        return False

    return bool(getattr(model, "is_pooling_model", False))


_T = TypeVar("_T", bound=type[nn.Module])


def default_pooling_type(pooling_type: str):
    """Build a class decorator that records a default pooling type.

    ``@default_pooling_type("CLS")`` stores ``"CLS"`` on the decorated class
    as its ``default_pooling_type`` attribute (the field declared by
    ``VllmModelForPooling``) and returns that same class object.
    """

    def _set_on_class(model_cls: _T) -> _T:
        # Stored on the class itself, so instances see it via normal lookup.
        setattr(model_cls, "default_pooling_type", pooling_type)
        return model_cls

    return _set_on_class


def get_default_pooling_type(model: Union[type[object], object]) -> str:
    """Return ``model.default_pooling_type``, falling back to ``"LAST"``.

    ``model`` may be a model class or an instance. The fallback is used when
    the attribute lookup raises ``AttributeError``.
    """
    try:
        return model.default_pooling_type  # type: ignore[union-attr]
    except AttributeError:
        return "LAST"