interfaces_base.py 5.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
from typing import (TYPE_CHECKING, List, Optional, Protocol, Type, Union,
                    overload, runtime_checkable)

import torch
import torch.nn as nn
from transformers import PretrainedConfig
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.config import CacheConfig
    from vllm.model_executor.layers.pooler import PoolerOutput
    from vllm.model_executor.layers.quantization import QuantizationConfig
    from vllm.model_executor.layers.sampler import SamplerOutput
    from vllm.model_executor.pooling_metadata import PoolingMetadata
    from vllm.model_executor.sampling_metadata import SamplingMetadata

# Module logger; the keyword checks below use it to warn about models whose
# signatures lack vLLM-specific keywords.
logger = init_logger(__name__)

# The type of the HF config passed to a model's `__init__`.
# Bound to `PretrainedConfig` and declared covariant.
C_co = TypeVar("C_co", bound=PretrainedConfig, covariant=True)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = List[torch.Tensor].
# `T` (invariant) parameterizes the task-specific protocols below, where it
# appears in both argument and return positions. `T_co` (covariant) is used
# by `VllmModel`, where it only appears as the `forward` return type.
# Both default to `torch.Tensor` when left unparameterized.
T = TypeVar("T", default=torch.Tensor)
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike the interfaces in `interfaces.py`, we don't define `ClassVar`
# tags for the base interfaces, to avoid breaking out-of-tree (OOT) model
# registration for existing models that don't inherit from the base
# interface classes.


@runtime_checkable
class VllmModel(Protocol[C_co, T_co]):
    """The interface required for all models in vLLM.

    Type parameters:
        C_co: the HF config type accepted by ``__init__``.
        T_co: the hidden-states type returned by ``forward``.

    NOTE: ``is_vllm_model`` does not use ``isinstance`` against this
    protocol. Instead, it checks (via ``supports_kw``) that ``__init__`` and
    ``forward`` accept the keywords declared here.
    """

    def __init__(
        self,
        config: C_co,
        *,
        cache_config: Optional["CacheConfig"],
        quant_config: Optional["QuantizationConfig"],
    ) -> None:
        """Build the model from its HF ``config``.

        ``cache_config`` and ``quant_config`` are keyword-only and may be
        ``None``.
        """
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
    ) -> T_co:
        """Run the model and return its hidden states (of type ``T_co``)."""
        ...


def _check_vllm_model_init(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.__init__`` accepts vLLM's config keywords.

    ``cache_config`` and ``quant_config`` are each tested with
    ``supports_kw``. If any are unsupported and ``model`` is a class
    deriving from ``nn.Module``, a warning naming them is logged.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` if every keyword is supported, ``False`` otherwise.
    """
    initializer = model.__init__
    required = ("cache_config", "quant_config")
    absent = tuple(name for name in required
                   if not supports_kw(initializer, name))

    if not absent:
        return True

    # Only classes deriving from nn.Module trigger the warning; instances
    # and other objects fail silently.
    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its initializer: %s",
            model,
            absent,
        )
    return False


def _check_vllm_model_forward(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts vLLM's forward keywords.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` if ``model`` has no callable ``forward`` attribute, or if
        ``forward`` does not accept every keyword in ``vllm_kws`` (as judged
        by ``supports_kw``). ``True`` otherwise.

    If keywords are missing and ``model`` is a class deriving from
    ``nn.Module``, a warning naming the missing keywords is logged.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    vllm_kws = ("input_ids", "positions", "kv_caches", "attn_metadata")
    missing_kws = tuple(kw for kw in vllm_kws
                        if not supports_kw(model_forward, kw))

    # Only classes deriving from nn.Module trigger the warning.
    if missing_kws and (isinstance(model, type)
                        and issubclass(model, nn.Module)):
        # This check inspects `forward`, so the message must name the
        # forward method rather than the initializer.
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its forward method: %s",
            model,
            missing_kws,
        )

    return len(missing_kws) == 0


@overload
def is_vllm_model(model: Type[object]) -> TypeIs[Type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModel]], TypeIs[VllmModel]]:
    """Report whether ``model`` satisfies the base vLLM model interface.

    Accepts either a model class or a model instance. The initializer is
    checked first, and ``forward`` is only checked if the initializer check
    passes.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` when both ``_check_vllm_model_init`` and
        ``_check_vllm_model_forward`` succeed, ``False`` otherwise.
    """
    if not _check_vllm_model_init(model):
        return False
    return _check_vllm_model_forward(model)


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[C_co, T], Protocol[C_co, T]):
    """The interface required for text generation models in vLLM.

    Extends ``VllmModel`` with ``compute_logits`` and ``sample``.
    """

    def compute_logits(
        self,
        hidden_states: T,
        sampling_metadata: "SamplingMetadata",
    ) -> Optional[T]:
        """Compute logits from ``hidden_states``.

        Return `None` if TP rank > 0.
        """
        ...

    def sample(
        self,
        logits: T,
        sampling_metadata: "SamplingMetadata",
    ) -> "SamplerOutput":
        """Produce a ``SamplerOutput`` from ``logits``.

        Only called on TP rank 0.
        """
        ...


@overload
def is_text_generation_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Report whether ``model`` is a vLLM text generation model.

    ``model`` must first pass ``is_vllm_model``. It is then checked against
    the runtime-checkable ``VllmModelForTextGeneration`` protocol.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` if both checks pass, ``False`` otherwise.
    """
    if not is_vllm_model(model):
        return False

    # One check serves both classes and instances. ``isinstance`` against a
    # runtime-checkable protocol only tests that the protocol's members are
    # present as attributes, and a class object exposes its methods as
    # attributes.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForEmbedding(VllmModel[C_co, T], Protocol[C_co, T]):
    """The interface required for embedding models in vLLM.

    Extends ``VllmModel`` with ``pooler``.
    """

    def pooler(
        self,
        hidden_states: T,
        pooling_metadata: "PoolingMetadata",
    ) -> "PoolerOutput":
        """Produce a ``PoolerOutput`` from ``hidden_states``.

        Only called on TP rank 0.
        """
        ...


@overload
def is_embedding_model(
        model: Type[object]) -> TypeIs[Type[VllmModelForEmbedding]]:
    ...


@overload
def is_embedding_model(model: object) -> TypeIs[VllmModelForEmbedding]:
    ...


def is_embedding_model(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[VllmModelForEmbedding]], TypeIs[VllmModelForEmbedding]]:
    """Report whether ``model`` is a vLLM embedding model.

    ``model`` must first pass ``is_vllm_model``. It is then checked against
    the runtime-checkable ``VllmModelForEmbedding`` protocol.

    Args:
        model: A model class or a model instance.

    Returns:
        ``True`` if both checks pass, ``False`` otherwise.
    """
    if not is_vllm_model(model):
        return False

    # One check serves both classes and instances. ``isinstance`` against a
    # runtime-checkable protocol only tests that the protocol's members are
    # present as attributes, and a class object exposes its methods as
    # attributes.
    return isinstance(model, VllmModelForEmbedding)