interfaces.py 19.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, MutableSequence
5
6
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)
7

8
import torch
9
from torch import Tensor
10
from typing_extensions import Self, TypeIs
11
12

from vllm.logger import init_logger
13
14
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
15
from vllm.utils import supports_kw
16

17
from .interfaces_base import is_pooling_model
18

19
if TYPE_CHECKING:
20
    from vllm.attention import AttentionMetadata
21
    from vllm.model_executor.models.utils import WeightsMapper
22
    from vllm.sequence import IntermediateTensors
23

24
25
26
27
if TYPE_CHECKING:
    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
    from vllm.sequence import IntermediateTensors

28
29
# Module-level logger; used by the capability checks in this module to warn
# about inconsistent model declarations.
logger = init_logger(__name__)

30
31
32
33
34
35
36
37
# Type alias for multimodal embedding values; the accepted layouts are
# described in the string that follows the assignment.
MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
38

39
40

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Get the placeholder text for the `i`th `modality` item in the prompt.
        """
        ...

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        ...

    # Concrete (non-overload) definition: without it the class attribute would
    # be the `typing.overload` dummy, which raises NotImplementedError when
    # called. Its signature is the union of the two overloads above.
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
113
114
class _SupportsMultiModalType(Protocol):
    supports_multimodal: Literal[True]
115
116
117


@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether `model` satisfies the multi-modal interface.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime protocol check: class objects are checked
        against `_SupportsMultiModalType`, everything else against
        `SupportsMultiModal`.
    """
    # A class object is treated as an "instance" of the class-level protocol.
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)
135
136
137
138
139
140


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # The `embedding_modules`, `embedding_padding_modules` and
    # `packed_modules_mapping` are empty by default.
    embedding_modules: ClassVar[dict[str, str]] = {}
    embedding_padding_modules: ClassVar[list[str]] = []
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
154
155
156
157
158
159
160
161


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

162
163
164
    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
    embedding_padding_modules: list[str]
165
166
167


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether `model` satisfies the LoRA interface.

    When the check fails, a warning is logged if the model looks
    half-configured: either it sets `supports_lora` but lacks some of the
    LoRA attributes, or it has every LoRA attribute but does not set the flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of `_supports_lora(model)`.
    """
    result = _supports_lora(model)
    if result:
        return result

    lora_attrs = (
        "packed_modules_mapping",
        "embedding_modules",
        "embedding_padding_modules",
    )
    missing_attrs = tuple(attr for attr in lora_attrs
                          if not hasattr(model, attr))

    if getattr(model, "supports_lora", False):
        # Flag is set but the structural check failed.
        if missing_attrs:
            logger.warning(
                "The model (%s) sets `supports_lora=True`, "
                "but is missing LoRA-specific attributes: %s",
                model,
                missing_attrs,
            )
    elif not missing_attrs:
        # Everything is present except the flag itself.
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.", model)

    return result


208
def _supports_lora(model: Union[type[object], object]) -> bool:
    """Run the runtime protocol check behind `supports_lora`.

    Class objects are checked against `_SupportsLoRAType`; anything else is
    checked against `SupportsLoRA`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
213
214


215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
        PP rank > 0.

        Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only
        for the last PP rank.
        """
        ...
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
268
        *,
269
        intermediate_tensors: Optional["IntermediateTensors"],
270
    ) -> Union[Tensor, "IntermediateTensors"]:
271
272
273
274
        ...


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether `model` satisfies the pipeline-parallel interface.

    Two conditions must both hold: the runtime protocol check
    (`_supports_pp_attributes`) and the `forward` signature check
    (`_supports_pp_inspect`). Warnings are logged when the model's
    declarations disagree with each other.

    Args:
        model: Either a model class or a model instance.

    Returns:
        `True` only if both checks pass.
    """
    supports_attributes = _supports_pp_attributes(model)
    supports_inspect = _supports_pp_inspect(model)

    # Protocol satisfied, but `forward` cannot receive the PP input.
    if supports_attributes and not supports_inspect:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not supports_attributes:
        pp_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(attr for attr in pp_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_pp", False):
            # Flag is set but the structural check failed.
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_pp=True`, "
                    "but is missing PP-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        elif not missing_attrs:
            # Everything is present except the flag itself.
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return supports_attributes and supports_inspect


317
def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
    """Run the runtime protocol check behind `supports_pp`.

    Class objects are checked against `_SupportsPPType`; anything else is
    checked against `SupportsPP`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


324
def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
    """Check whether the model's `forward` accepts `intermediate_tensors`.

    Returns `False` when `model` has no callable `forward` attribute;
    otherwise returns what `supports_kw` reports for the
    `intermediate_tensors` keyword.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")
330
331


332
333
334
335
336
337
338
339
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has inner state.
        Models that have inner state usually need access to the
        scheduler_config for max_num_seqs, etc. True for e.g. both Mamba
        and Jamba.
    """


@runtime_checkable
class _HasInnerStateType(Protocol):
    has_inner_state: ClassVar[Literal[True]]


# The `type[object]` overload comes first: `type[object]` is a subtype of
# `object`, so listing the broader overload first would make this one
# unreachable for type checkers.
@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether `model` declares that it has inner state.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime protocol check: class objects are checked
        against `_HasInnerStateType`, everything else against
        `HasInnerState`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391


@runtime_checkable
class IsAttentionFree(Protocol):
    """
    The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant wrt the number of tokens.
    """

    is_attention_free: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has no attention.

    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    is_attention_free: ClassVar[Literal[True]]


# The `type[object]` overload comes first: `type[object]` is a subtype of
# `object`, so listing the broader overload first would make this one
# unreachable for type checkers.
@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether `model` declares itself attention-free.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime protocol check: class objects are checked
        against `_IsAttentionFreeType`, everything else against
        `IsAttentionFree`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
403
404


405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
@runtime_checkable
class IsHybrid(Protocol):
    """
    The interface required for all models like Jamba that have both
    attention and mamba blocks; it also indicates that hf_config has
    'layers_block_type'.
    """

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks.

    It also indicates that the model's hf_config has 'layers_block_type'.
    """


@runtime_checkable
class _IsHybridType(Protocol):
    is_hybrid: ClassVar[Literal[True]]


# The `type[object]` overload comes first: `type[object]` is a subtype of
# `object`, so listing the broader overload first would make this one
# unreachable for type checkers.
@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
    """Return whether `model` declares itself a hybrid model.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime protocol check: class objects are checked
        against `_IsHybridType`, everything else against `IsHybrid`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Structural interface used to check whether a model is a mixture of
    experts (MoE) model.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    Expert weights saved in this rank.

    Indexed first by layer, then by the individual parameters of that layer
    (e.g. up/down projection weights).
    """

    num_moe_layers: int
    """How many MoE layers this model has."""

    num_expert_groups: int
    """How many expert groups this model has."""

    num_logical_experts: int
    """How many logical experts this model has."""

    num_physical_experts: int
    """How many physical experts this model has."""

    num_local_physical_experts: int
    """How many local physical experts this model has."""

    num_routed_experts: int
    """How many routed experts this model has."""

    num_shared_experts: int
    """How many shared experts this model has."""

    num_redundant_experts: int
    """How many redundant experts this model has."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Register the EPLB state in the MoE model.

        The arguments are views of the actual EPLB state, so changes made by
        the EPLB algorithm are reflected in the model's behavior
        automatically; no further calls are needed to push new states.

        The model's `expert_weights` should also be collected here rather
        than in the weight loader, because additional processing such as
        quantization may be applied to the weights after the initial load.

        Args:
            expert_load_view: A view of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Count of replicas for each logical expert.
        """
        ...


def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return whether `model` passes the `MixtureOfExperts` runtime check."""
    matches_protocol = isinstance(model, MixtureOfExperts)
    return matches_protocol


509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
@runtime_checkable
class HasNoOps(Protocol):
    """Marker interface for models that set the `has_noops` flag."""

    has_noops: ClassVar[Literal[True]] = True


@runtime_checkable
class _HasNoOpsType(Protocol):
    has_noops: ClassVar[Literal[True]]


# The `type[object]` overload comes first: `type[object]` is a subtype of
# `object`, so listing the broader overload first would make this one
# unreachable for type checkers.
@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
    """Return whether `model` exposes the `has_noops` flag.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of a runtime protocol check: class objects are checked
        against `_HasNoOpsType`, everything else against `HasNoOps`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasNoOpsType)

    return isinstance(model, HasNoOps)


538
539
540
541
542
543
544
545
546
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Interface that every model supporting cross encoding must satisfy."""

    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Run the runtime protocol check behind `supports_cross_encoding`.

    The same `isinstance` check is applied whether `model` is a class or an
    instance, so no branching on `type` is needed.
    """
    return isinstance(model, SupportsCrossEncoding)


# The helper above is defined first so that the overloads sit directly in
# front of their implementation.
@overload
def supports_cross_encoding(
        model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether `model` is a pooling model that supports cross encoding.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of `is_pooling_model(model)` when that is falsy;
        otherwise the result of the `SupportsCrossEncoding` protocol check.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)
570
571


572
573
574
575
576
577
def has_step_pooler(model: Union[type[object], object]) -> bool:
    """Check if the model uses step pooler.

    The check requires `is_pooling_model(model)` to be truthy and at least
    one entry yielded by `model.modules()` to have the class name
    ``StepPool``.
    """
    return is_pooling_model(model) and "StepPool" in (
        type(submodule).__name__ for submodule in model.modules())


578
579
580
class SupportsQuant:
    """The interface required for all models that support quantization."""

    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Create the instance and attach any quant config found in the args.

        The constructor arguments are scanned by `_find_quant_config`. When a
        config is found it is stored on the instance and the model's
        `hf_to_vllm_mapper` / `packed_modules_mapping` (when not `None`) are
        applied to it.
        """
        instance = super().__new__(cls)

        # find config passed in arguments
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is not None:

            # attach config to model for general use
            instance.quant_config = quant_config

            # apply model mappings to config for proper config-model matching
            # NOTE: `TransformersForCausalLM` is not supported due to how this
            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
            mapper = getattr(instance, "hf_to_vllm_mapper", None)
            if mapper is not None:
                quant_config.apply_vllm_mapper(mapper)

            packed_mapping = getattr(instance, "packed_modules_mapping", None)
            if packed_mapping is not None:
                quant_config.packed_modules_mapping.update(packed_mapping)

        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Find quant config passed through model constructor args.

        Positional arguments are scanned first, then keyword values. The
        first `VllmConfig` found yields its `quant_config`; the first
        `QuantizationConfig` found is returned as-is. Returns `None` when
        neither is present.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for arg in (*args, *kwargs.values()):
            if isinstance(arg, VllmConfig):
                return arg.quant_config

            if isinstance(arg, QuantizationConfig):
                return arg

        return None


624
625
626
627
628
629
@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    supports_transcription: ClassVar[Literal[True]] = True

    @classmethod
    def get_decoder_prompt(cls, language: str, task_type: str,
                           prompt: str) -> str:
        """Get the decoder prompt for the ASR model."""
        ...

    @classmethod
    def validate_language(cls, language: str) -> bool:
        """Check if the model supports a specific ISO639_1 language."""
        ...

641
642
643

@overload
def supports_transcription(
        model: type[object]) -> TypeIs[type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Return whether `model` satisfies the transcription interface.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of the `SupportsTranscription` runtime protocol check.
        The same check is applied to classes and instances, so no branching
        on `type` is needed.
    """
    return isinstance(model, SupportsTranscription)
660
661
662
663
664
665
666
667
668
669


@runtime_checkable
class SupportsV0Only(Protocol):
    """Marks models that cannot be used with V1 vLLM."""

    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Return whether `model` is marked as V0-only.

    Args:
        model: Either a model class or a model instance.

    Returns:
        The result of the `SupportsV0Only` runtime protocol check. The same
        check is applied to classes and instances, so no branching on `type`
        is needed.
    """
    return isinstance(model, SupportsV0Only)