# interfaces.py
import inspect
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                    Protocol, Type, Union, overload, runtime_checkable)
4

5
import torch
6
from typing_extensions import TypeIs
7
8
9

from vllm.logger import init_logger

# These names are only needed by static type checkers: every annotation in
# this module refers to them as strings (forward references), so importing
# them at runtime is unnecessary.
if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
    from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Structural view of a model *class* implementing SupportsMultiModal.

    The class object itself is passed to ``isinstance``, so the flag is
    declared as a plain attribute and the constructor signature is modelled
    as ``__call__`` (calling a class constructs an instance).
    """

    supports_multimodal: Literal[True]

    def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
        ...


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Check whether ``model`` (a class or an instance) is multi-modal.

    A class is checked against ``_SupportsMultiModalType``, treating the
    class object as an instance of that protocol. Any other object is
    checked against :class:`SupportsMultiModal`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # Declared without defaults, so a model must define each of these itself
    # for the runtime isinstance() check against this protocol to succeed.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Structural view of a model *class* implementing SupportsLoRA.

    The class object itself is passed to ``isinstance``, so the class-level
    attributes are declared as plain attributes and the constructor
    signature is modelled as ``__call__``.
    """

    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

    def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
        ...


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Check whether ``model`` (a class or an instance) supports LoRA.

    When the structural check fails, a warning is logged if the model looks
    half-configured: it sets ``supports_lora=True`` but is missing some
    LoRA-specific attributes, or it has all of them but does not set the
    flag.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Structural LoRA check used by ``supports_lora``; logs nothing.

    A class is checked against ``_SupportsLoRAType`` (the class object is
    treated as an instance of that protocol). Any other object is checked
    against :class:`SupportsLoRA`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)


@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.
        """
        ...


# A runtime_checkable protocol whose flag is a ClassVar cannot be used with
# issubclass(), so the model class object is checked with isinstance()
# against this protocol instead, as though the class were an instance.
@runtime_checkable
class _SupportsPPType(Protocol):
    """Structural view of a model *class* implementing SupportsPP."""

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Same signature as ``SupportsPP.make_empty_intermediate_tensors``."""

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: "AttentionMetadata",
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """Same signature as ``SupportsPP.forward``."""


@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Check whether ``model`` (a class or an instance) supports PP.

    Both conditions must hold: the model exposes the SupportsPP attributes,
    and its ``forward`` accepts an ``intermediate_tensors`` parameter.
    Inconsistent setups (flag set but pieces missing, or pieces present but
    flag unset) are reported through ``logger.warning``.
    """
    has_pp_attrs = _supports_pp_attributes(model)
    accepts_intermediate = _supports_pp_inspect(model)

    if has_pp_attrs and not accepts_intermediate:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not has_pp_attrs:
        absent = tuple(
            name for name in ("make_empty_intermediate_tensors", )
            if not hasattr(model, name))
        flagged = bool(getattr(model, "supports_pp", False))

        if flagged and absent:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                absent,
            )
        elif not flagged and not absent:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_pp_attrs and accepts_intermediate


def _supports_pp_attributes(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Structural check for the SupportsPP attributes.

    Model classes are matched against ``_SupportsPPType``; anything else is
    matched against ``SupportsPP``.
    """
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)


def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts ``intermediate_tensors``.

    Works for model classes and instances alike, since ``forward`` is looked
    up with :func:`getattr` in both cases.

    Returns:
        ``True`` if ``model`` has a callable ``forward`` whose signature has
        a parameter named ``intermediate_tensors``. ``False`` if ``forward``
        is missing, not callable, or has a signature that cannot be
        inspected.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    try:
        forward_params = inspect.signature(model_forward).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these for callables it cannot introspect;
        # PP support cannot be confirmed for such a forward.
        return False

    return "intermediate_tensors" in forward_params


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. (Currently only used by Jamba.)
    """

    def __init__(self,
                 *,
                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _HasInnerStateType(Protocol):
    """Structural view of a model *class* implementing HasInnerState.

    The class object itself is passed to ``isinstance``, so, as in the other
    ``_*Type`` protocols here, the flag is a plain attribute rather than a
    ClassVar, and the constructor signature is modelled as ``__call__``.
    """

    has_inner_state: Literal[True]

    def __call__(self,
                 *,
                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
        ...


# The Type[object] overload must come first: every class is also an
# ``object``, so listing the ``object`` overload first would make the class
# overload unreachable for type checkers.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Check whether ``model`` (a class or an instance) has inner state.

    A class is checked against ``_HasInnerStateType``, treating the class
    object as an instance of that protocol. Any other object is checked
    against :class:`HasInnerState`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)