# interfaces.py
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                    Protocol, Type, Union, overload, runtime_checkable)

import torch
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

from .interfaces_base import is_pooling_model

# These names are only needed for annotations, so they are imported for the
# type checker only and referenced as strings at runtime.
if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.multimodal.inputs import NestedTensors  # noqa: F401
    from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)

# Type variable for multimodal embeddings; defaults to ``NestedTensors``.
T = TypeVar("T", default="NestedTensors")

21
22

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Args:
            **kwargs: The multimodal keyword arguments.

        Returns:
            The multimodal embeddings, or ``None``.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> torch.Tensor:
        ...

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
    ) -> torch.Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.

        Args:
            input_ids: The token ids to embed.
            multimodal_embeddings: Optional multimodal embeddings to merge
                with the text embeddings.

        Returns:
            The merged input embeddings.
        """
        ...

65
66
67
68

# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Protocol used to check a model *class* object with ``isinstance``.

    The flag is declared as a plain attribute (not a ``ClassVar``) so that
    the class object itself is treated as the "instance" being checked.
    """

    supports_multimodal: Literal[True]
71
72
73


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether ``model`` satisfies the multi-modal interface.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if a class object passes the runtime check against
        ``_SupportsMultiModalType``, or an instance passes the runtime
        check against ``SupportsMultiModal``; ``False`` otherwise.
    """
    # Class objects are checked as if they were instances of the
    # ClassVar-free protocol; real instances use the public protocol.
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)
91
92
93
94
95
96


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # LoRA-specific class attributes; they have no default here, so a
    # conforming model class has to define each of them.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Protocol used to check a model *class* object with ``isinstance``.

    It declares the same members as ``SupportsLoRA``, but as plain
    attributes instead of ``ClassVar`` ones.
    """

    # Capability flag.
    supports_lora: Literal[True]

    # LoRA-specific attributes expected on the class object.
    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether ``model`` satisfies the LoRA interface.

    When the check fails, a warning is logged if the model looks partially
    configured: either it sets ``supports_lora`` but lacks some of the
    LoRA-specific attributes, or it has all of those attributes but does
    not set ``supports_lora``.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of the structural check done by ``_supports_lora``.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            # Flag is set but the supporting attributes are incomplete.
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            # Attributes are all present but the flag is not set.
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


166
def _supports_lora(model: Union[Type[object], object]) -> bool:
    """Run the structural LoRA check without logging anything.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if a class object passes the runtime check against
        ``_SupportsLoRAType``, or an instance passes the runtime check
        against ``SupportsLoRA``; ``False`` otherwise.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
171
172


173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Protocol used to check a model *class* object with ``isinstance``.

    It declares the same members as ``SupportsPP``, with the flag as a
    plain attribute instead of a ``ClassVar``.
    """

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether ``model`` satisfies the pipeline-parallel interface.

    Two checks must both pass: the structural check done by
    ``_supports_pp_attributes`` and the ``forward`` signature check done
    by ``_supports_pp_inspect``. A warning is logged when the model looks
    partially configured for pipeline parallel.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` only if both checks pass.
    """
    has_attributes = _supports_pp_attributes(model)
    accepts_tensors = _supports_pp_inspect(model)

    if has_attributes:
        # Structurally fine, but ``forward`` lacks the required keyword.
        if not accepts_tensors:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method", model)
    else:
        required = ("make_empty_intermediate_tensors", )
        absent = tuple(name for name in required
                       if not hasattr(model, name))
        flag_set = getattr(model, "supports_pp", False)

        if flag_set and absent:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                absent,
            )
        elif not flag_set and not absent:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_attributes and accepts_tensors


273
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Run the structural pipeline-parallel check without logging.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if a class object passes the runtime check against
        ``_SupportsPPType``, or an instance passes the runtime check
        against ``SupportsPP``; ``False`` otherwise.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


280
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Check the ``forward`` attribute of ``model`` for PP support.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``False`` if ``model`` has no callable ``forward`` attribute;
        otherwise the result of
        ``supports_kw(model.forward, "intermediate_tensors")``.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")
286
287


288
289
290
291
292
293
294
295
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.

    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
    """


@runtime_checkable
class _HasInnerStateType(Protocol):
    """Protocol used to check a model *class* object with ``isinstance``.

    The flag is declared as a plain attribute rather than a ``ClassVar``,
    matching ``_SupportsMultiModalType``, ``_SupportsLoRAType`` and
    ``_SupportsPPType``: the class object is the "instance" being checked.
    """

    has_inner_state: Literal[True]


# The more specific ``Type[object]`` overload has to come first: every
# argument is an ``object``, so listing that overload first would make
# the class overload unreachable for type checkers.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether ``model`` declares that it has inner state.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if a class object passes the runtime check against
        ``_HasInnerStateType``, or an instance passes the runtime check
        against ``HasInnerState``; ``False`` otherwise.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358


@runtime_checkable
class IsAttentionFree(Protocol):
    """Interface for models that lack attention (e.g. Mamba) but do have
    state whose size is constant with respect to the number of tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has no attention.

    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """Protocol used to check a model *class* object with ``isinstance``.

    The flag is declared as a plain attribute rather than a ``ClassVar``,
    matching ``_SupportsMultiModalType``, ``_SupportsLoRAType`` and
    ``_SupportsPPType``: the class object is the "instance" being checked.
    """

    is_attention_free: Literal[True]


# The more specific ``Type[object]`` overload has to come first: every
# argument is an ``object``, so listing that overload first would make
# the class overload unreachable for type checkers.
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` declares that it is attention-free.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if a class object passes the runtime check against
        ``_IsAttentionFreeType``, or an instance passes the runtime check
        against ``IsAttentionFree``; ``False`` otherwise.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391


@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Interface that every model supporting cross encoding must satisfy."""

    # Flag indicating that the model supports cross encoding.
    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Run the runtime ``SupportsCrossEncoding`` check on ``model``.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` if ``model`` passes the runtime protocol check.
    """
    # The same protocol is used for class objects and instances, so no
    # branching on ``isinstance(model, type)`` is needed.
    return isinstance(model, SupportsCrossEncoding)


# The overloads are kept directly above the implementation they describe.
@overload
def supports_cross_encoding(
        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` is a pooling model that supports cross
    encoding.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``True`` only if ``is_pooling_model(model)`` holds and ``model``
        passes the runtime ``SupportsCrossEncoding`` check.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)