interfaces.py 15.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                    Protocol, Type, Union, overload, runtime_checkable)
5

6
import torch
7
from torch import Tensor
8
from typing_extensions import Self, TypeIs
9
10

from vllm.logger import init_logger
11
12
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
13
from vllm.utils import supports_kw
14

15
from .interfaces_base import is_pooling_model
16

17
if TYPE_CHECKING:
18
    from vllm.attention import AttentionMetadata
19
    from vllm.sequence import IntermediateTensors
20

21
22
23
24
if TYPE_CHECKING:
    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
    from vllm.sequence import IntermediateTensors

25
26
# Module-level logger. `supports_lora` and `supports_pp` in this module use it
# to warn when a model's capability flag and its attributes disagree.
logger = init_logger(__name__)

27
28
29
30
31
32
33
34
# Type alias used as the (optional) return type of
# `SupportsMultiModal.get_multimodal_embeddings` and as a parameter type of
# `get_input_embeddings`: a list of tensors, a single tensor, or a tuple of
# tensors. The string literal that follows documents the accepted layouts.
MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
35

36
37

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    # Both signatures are declared with `@overload` and no runtime
    # implementation follows: this class only describes the interface.
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    The flag is declared as a plain attribute (no ``ClassVar``) because the
    class object itself is the thing being checked.
    """

    supports_multimodal: Literal[True]
105
106
107


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether ``model`` satisfies the multi-modal interface.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime ``isinstance`` check: classes are checked
        against ``_SupportsMultiModalType``, instances against
        ``SupportsMultiModal``.
    """
    if isinstance(model, type):
        # A class object is treated as an "instance" of the helper protocol.
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)
125
126
127
128
129
130


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # `embedding_modules`, `embedding_padding_modules` and
    # `packed_modules_mapping` are empty by default.
    embedding_modules: ClassVar[Dict[str, str]] = {}
    embedding_padding_modules: ClassVar[List[str]] = []
    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
144
145
146
147
148
149
150
151
152
153
154
155
156
157


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    It lists the same four attributes as ``SupportsLoRA`` but as plain
    attributes, since the class object itself is what gets checked.
    """

    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether ``model`` satisfies the LoRA interface.

    When the check fails, a warning is logged if the model looks
    half-configured: the flag is set but attributes are missing, or every
    attribute is present but the flag is not set.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of ``_supports_lora(model)``.
    """
    result = _supports_lora(model)
    if result:
        return result

    lora_attrs = (
        "packed_modules_mapping",
        "embedding_modules",
        "embedding_padding_modules",
    )
    missing_attrs = tuple(attr for attr in lora_attrs
                          if not hasattr(model, attr))
    flag_is_set = getattr(model, "supports_lora", False)

    if flag_is_set and missing_attrs:
        logger.warning(
            "The model (%s) sets `supports_lora=True`, "
            "but is missing LoRA-specific attributes: %s",
            model,
            missing_attrs,
        )
    elif not flag_is_set and not missing_attrs:
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.", model)

    return result


198
def _supports_lora(model: Union[Type[object], object]) -> bool:
    """Run the bare protocol check behind ``supports_lora`` (no logging).

    Classes are checked against ``_SupportsLoRAType`` and instances against
    ``SupportsLoRA``.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
203
204


205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.

        NOTE(review): the sentence above is carried over unchanged. This
        file does not show which ranks actually return
        :class:`IntermediateTensors`; confirm against the model
        implementations.
        """
        ...
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    It lists the same members as ``SupportsPP``, with the flag declared as a
    plain attribute because the class object itself is what gets checked.
    """

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether ``model`` supports pipeline parallel.

    Two checks must both pass: the protocol check
    (``_supports_pp_attributes``) and the signature check
    (``_supports_pp_inspect``). A warning is logged when the model looks
    half-configured.

    Args:
        model: Either a model class or a model instance.
    """
    # Both checks always run, whatever the first one returns.
    has_pp_attrs = _supports_pp_attributes(model)
    forward_takes_it = _supports_pp_inspect(model)

    if has_pp_attrs:
        if not forward_takes_it:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method", model)
    else:
        required = ("make_empty_intermediate_tensors", )
        absent = tuple(name for name in required if not hasattr(model, name))
        flag_is_set = getattr(model, "supports_pp", False)

        if flag_is_set and absent:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                absent,
            )
        elif not flag_is_set and not absent:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_pp_attrs and forward_takes_it


305
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Run the protocol check behind ``supports_pp``.

    Classes are checked against ``_SupportsPPType`` and instances against
    ``SupportsPP``.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


312
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Check whether ``model.forward`` accepts ``intermediate_tensors``.

    Returns:
        ``False`` when ``model`` has no callable ``forward`` attribute;
        otherwise the result of
        ``supports_kw(model.forward, "intermediate_tensors")``.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")
318
319


320
321
322
323
324
325
326
327
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has inner state.
        Models that have inner state usually need access to the
        scheduler_config for max_num_seqs, etc. True for e.g. both Mamba
        and Jamba.
    """


@runtime_checkable
class _HasInnerStateType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    The flag is a plain attribute rather than a ``ClassVar``, matching the
    other ``_*Type`` helpers in this module: the class object itself is the
    thing being checked.
    """

    has_inner_state: Literal[True]


# The `Type[object]` overload comes first: `object` also accepts classes, so
# listing it first would leave this overload unmatchable for type checkers.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether ``model`` declares the ``has_inner_state`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime ``isinstance`` check: classes are checked
        against ``_HasInnerStateType``, instances against ``HasInnerState``.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390


@runtime_checkable
class IsAttentionFree(Protocol):
    """Interface for attention-free models.

    Required for all models like Mamba that lack attention, but do have
    state whose size is constant wrt the number of tokens.
    """

    is_attention_free: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has no attention.

    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    The flag is a plain attribute rather than a ``ClassVar``, matching the
    other ``_*Type`` helpers in this module: the class object itself is the
    thing being checked.
    """

    is_attention_free: Literal[True]


# The `Type[object]` overload comes first: `object` also accepts classes, so
# listing it first would leave this overload unmatchable for type checkers.
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` declares the ``is_attention_free`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime ``isinstance`` check: classes are checked
        against ``_IsAttentionFreeType``, instances against
        ``IsAttentionFree``.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
391
392


393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
@runtime_checkable
class IsHybrid(Protocol):
    """Interface for hybrid attention/mamba models.

    Required for all models like Jamba that have both attention and mamba
    blocks; it also indicates that hf_config has 'layers_block_type'.
    """

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks.

    It also indicates that the model's hf_config has 'layers_block_type'.
    """


@runtime_checkable
class _IsHybridType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    The flag is a plain attribute rather than a ``ClassVar``, matching the
    other ``_*Type`` helpers in this module: the class object itself is the
    thing being checked.
    """

    is_hybrid: Literal[True]


# The `Type[object]` overload comes first: `object` also accepts classes, so
# listing it first would leave this overload unmatchable for type checkers.
@overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
    """Return whether ``model`` declares the ``is_hybrid`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime ``isinstance`` check: classes are checked
        against ``_IsHybridType``, instances against ``IsHybrid``.
    """
    if isinstance(model, type):
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
@runtime_checkable
class HasNoOps(Protocol):
    """Marker interface for models that declare ``has_noops = True``.

    NOTE(review): what "no-ops" means for a model is not described in this
    file; document it here once confirmed.
    """

    has_noops: ClassVar[Literal[True]] = True


@runtime_checkable
class _HasNoOpsType(Protocol):
    """Helper protocol used to test a model *class* with ``isinstance``.

    The flag is a plain attribute rather than a ``ClassVar``, matching the
    other ``_*Type`` helpers in this module: the class object itself is the
    thing being checked.
    """

    has_noops: Literal[True]


# The `Type[object]` overload comes first: `object` also accepts classes, so
# listing it first would leave this overload unmatchable for type checkers.
@overload
def has_noops(model: Type[object]) -> TypeIs[Type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasNoOps]], TypeIs[HasNoOps]]:
    """Return whether ``model`` declares the ``has_noops`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime ``isinstance`` check: classes are checked
        against ``_HasNoOpsType``, instances against ``HasNoOps``.
    """
    if isinstance(model, type):
        return isinstance(model, _HasNoOpsType)

    return isinstance(model, HasNoOps)


459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Interface that every model supporting cross encoding must satisfy.

    The only member is the ``supports_cross_encoding`` flag below.
    """

    supports_cross_encoding: ClassVar[Literal[True]] = True


# Defined ahead of the overloads so that the `@overload` stubs and the
# implementation of `supports_cross_encoding` stay adjacent.
def _supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Run the bare protocol check behind ``supports_cross_encoding``.

    Classes and instances go through the same ``isinstance`` check against
    ``SupportsCrossEncoding``.
    """
    return isinstance(model, SupportsCrossEncoding)


@overload
def supports_cross_encoding(
        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` is a pooling model that supports cross encoding.

    Args:
        model: Either a model class or a model instance.

    Returns:
        ``is_pooling_model(model) and _supports_cross_encoding(model)``; the
        protocol check only runs when the pooling check is truthy.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)
491
492


493
494
495
496
497
498
class SupportsQuant:
    """The interface required for all models that support quantization."""

    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Create the instance and attach a quantization config if one is found.

        The constructor arguments are scanned with ``_find_quant_config``.
        When a config is found it is stored on the instance, and the class's
        ``packed_modules_mapping`` is merged into the config's own mapping.
        """
        instance = super().__new__(cls)
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is not None:
            instance.quant_config = quant_config
            # `update` mutates the config's mapping in place.
            instance.quant_config.packed_modules_mapping.update(
                cls.packed_modules_mapping)
        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Return the quantization config carried by the arguments, if any.

        Positional arguments are scanned first, then keyword values. The
        first ``VllmConfig`` or ``QuantizationConfig`` found decides the
        result; for a ``VllmConfig`` that is its ``quant_config`` attribute,
        even when that attribute is ``None``.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for arg in (*args, *kwargs.values()):
            if isinstance(arg, VllmConfig):
                return arg.quant_config

            if isinstance(arg, QuantizationConfig):
                return arg

        return None


523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
@runtime_checkable
class SupportsTranscription(Protocol):
    """Interface that every model supporting transcription must satisfy.

    The only member is the ``supports_transcription`` flag below.
    """

    supports_transcription: ClassVar[Literal[True]] = True


@overload
def supports_transcription(
        model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Return whether ``model`` declares the ``supports_transcription`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of ``isinstance(model, SupportsTranscription)``. Classes
        and instances go through the same check.
    """
    return isinstance(model, SupportsTranscription)
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573


@runtime_checkable
class SupportsV0Only(Protocol):
    """Marker interface: models declaring it are not compatible with V1 vLLM.

    The only member is the ``supports_v0_only`` flag below.
    """

    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: Type[object]) -> TypeIs[Type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Return whether ``model`` declares the ``supports_v0_only`` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of ``isinstance(model, SupportsV0Only)``. Classes and
        instances go through the same check.
    """
    return isinstance(model, SupportsV0Only)