interfaces.py 21.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, MutableSequence
5
6
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)
7

8
import numpy as np
9
import torch
10
from torch import Tensor
11
from typing_extensions import Self, TypeIs
12

13
from vllm.config import ModelConfig, SpeechToTextConfig
14
from vllm.inputs import TokensPrompt
15
from vllm.inputs.data import PromptType
16
from vllm.logger import init_logger
17
18
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
19
from vllm.utils import supports_kw
20

21
from .interfaces_base import is_pooling_model
22

23
if TYPE_CHECKING:
24
    from vllm.attention import AttentionMetadata
25
    from vllm.config import VllmConfig
26
    from vllm.model_executor.models.utils import WeightsMapper
27
28
    from vllm.sequence import IntermediateTensors

29
30
# Module-level logger; the interface checks below (e.g. `supports_lora`,
# `supports_pp`) use it to warn when a model's declared flag and its actual
# attributes disagree.
logger = init_logger(__name__)
MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g., image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Get the placeholder text for the `i`th `modality` item in the prompt.
        """
        ...

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    # TODO: Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        ...

    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        # Only necessary so that the v0 overload is valid
        # TODO: Remove attn_metadata once v0 is deprecated
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...

@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Check whether `model` (a class or an instance) is multi-modal.

    Only the `supports_multimodal` flag is inspected: its value is returned,
    or False when the attribute is absent. The methods of
    `SupportsMultiModal` are not checked.
    """
    return getattr(model, "supports_multimodal", False)
@runtime_checkable
class SupportsScoreTemplate(Protocol):
    """Interface for models that build score prompts from a template.

    Implementers provide a template that is filled with a query and a
    document, and may adjust the resulting tokenized prompt.
    """

    supports_score_template: ClassVar[Literal[True]] = True
    """
    Marks the model as supporting a score template.

    Note:
        Models that have this class in their MRO inherit the flag, so it
        does not need to be redefined.
    """

    @classmethod
    def get_score_template(cls, query: str, document: str) -> Optional[str]:
        """
        Build the complete prompt by filling the score template with the
        query and document text.
        """
        ...

    @classmethod
    def post_process_tokens(cls, prompt: TokensPrompt) -> None:
        """
        Apply architecture-specific adjustments to the tokenized prompt.
        """
        ...

@overload
def supports_score_template(
        model: type[object]) -> TypeIs[type[SupportsScoreTemplate]]:
    ...


@overload
def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]:
    ...


def supports_score_template(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsScoreTemplate]], TypeIs[SupportsScoreTemplate]]:
    """Return the `supports_score_template` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "supports_score_template", False)
@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # `embedding_modules`, `embedding_padding_modules` and
    # `packed_modules_mapping` are empty by default.
    embedding_modules: ClassVar[dict[str, str]] = {}
    embedding_padding_modules: ClassVar[list[str]] = []
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
    embedding_padding_modules: list[str]


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Check whether `model` (a class or an instance) implements LoRA support.

    When the check fails, a warning is logged if the `supports_lora` flag
    and the LoRA-specific attributes disagree. This happens either when the
    flag is set but some attributes are missing, or when every attribute is
    present but the flag is not set.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(model: Union[type[object], object]) -> bool:
    """Structurally check `model` for LoRA support.

    A class is checked with `isinstance` against `_SupportsLoRAType`, which
    declares the members as plain annotations. `SupportsLoRA` declares them
    as `ClassVar`s, so it cannot be used for that check. An instance is
    checked against `SupportsLoRA` directly.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
        PP rank > 0.

        Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] on
        every PP rank except the last one, which returns a `Tensor`.
        """
        # NOTE(review): the previous docstring said IntermediateTensors are
        # returned "only for the last PP rank". The wording above follows the
        # usual vLLM convention, where non-last ranks hand IntermediateTensors
        # to the next stage. Confirm against the model implementations.
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
    """Check whether `model` (a class or an instance) supports pipeline
    parallel.

    Two conditions must both hold. First, the model must structurally match
    the PP interface. Second, its `forward` must accept an
    `intermediate_tensors` keyword argument.

    Inconsistencies are logged as warnings:
    - the interface matches but `forward` lacks the keyword;
    - `supports_pp` is set but required attributes are missing;
    - all required attributes are present but the flag is not set.
    """
    supports_attributes = _supports_pp_attributes(model)
    supports_inspect = _supports_pp_inspect(model)

    if supports_attributes and not supports_inspect:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not supports_attributes:
        pp_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(attr for attr in pp_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_pp", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_pp=True`, "
                    "but is missing PP-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all PP-specific attributes, "
                    "but does not set `supports_pp=True`.", model)

    return supports_attributes and supports_inspect


def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
    """Structurally check `model` against the PP interface.

    Classes are checked with `isinstance` against `_SupportsPPType`, and
    instances against `SupportsPP`. See the comment on `_SupportsPPType` for
    why the two are separate.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
    """Return whether `model.forward` accepts an `intermediate_tensors`
    keyword, as reported by `supports_kw`.

    Returns False when `model` has no callable `forward`.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has inner state.
        Models that have inner state usually need access to the
        scheduler_config for max_num_seqs, etc. True for e.g. both Mamba and
        Jamba.
    """


# The `type[object]` overload must come first: `type[object]` is a subtype of
# `object`, so listing the `object` overload first would shadow it and
# class arguments would never be narrowed to `type[HasInnerState]`.
@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return the `has_inner_state` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "has_inner_state", False)


@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant wrt the number of tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has no attention.
        Used for block manager and attention backend selection.
        True for Mamba but not Jamba.
    """


# The `type[object]` overload must come first; otherwise the broader
# `object` overload shadows it and classes are never narrowed correctly.
@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return the `is_attention_free` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "is_attention_free", False)
@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models like Jamba that have both
    attention and mamba blocks. It also indicates that the model's
    hf_config has 'layers_block_type'."""

    is_hybrid: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has both mamba and attention blocks.
        It also indicates that the model's hf_config has
        'layers_block_type'.
    """

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
        use_v1: bool = True,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config
            use_v1: Get shapes for V1 (or V0)

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        ...


# The `type[object]` overload must come first; otherwise the broader
# `object` overload shadows it and classes are never narrowed correctly.
@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
    """Return the `is_hybrid` flag of `model` (False if unset).

    Accepts either a model class or a model instance. Only the flag is
    inspected, not `get_mamba_state_shape_from_config`.
    """
    return getattr(model, "is_hybrid", False)
@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Interface for mixture of experts (MoE) models.

    Use `is_mixture_of_experts` to test whether a model implements it.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    The expert weights held by this rank.

    Indexed first by layer; the inner iterable holds the individual
    parameters of that layer, e.g. the up/down projection weights.
    """

    num_moe_layers: int
    """How many MoE layers the model contains."""

    num_expert_groups: int
    """How many expert groups the model contains."""

    num_logical_experts: int
    """How many logical experts the model contains."""

    num_physical_experts: int
    """How many physical experts the model contains."""

    num_local_physical_experts: int
    """How many local physical experts the model contains."""

    num_routed_experts: int
    """How many routed experts the model contains."""

    num_shared_experts: int
    """How many shared experts the model contains."""

    num_redundant_experts: int
    """How many redundant experts the model contains."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Attach the EPLB state to the MoE model.

        The arguments are views of the actual EPLB state. Any later change
        made by the EPLB algorithm is therefore seen by the model
        automatically, and no further calls are needed to push new state.

        Collect the model's `expert_weights` here rather than in the weight
        loader. Once the initial weights are loaded, further processing such
        as quantization may still be applied to them.

        Args:
            expert_load_view: View of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Number of replicas per logical expert.
        """
        ...


def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return whether `model` structurally matches `MixtureOfExperts`.

    Being `runtime_checkable`, the check only verifies that the protocol's
    members are present, not their types.
    """
    matches = isinstance(model, MixtureOfExperts)
    return matches
@runtime_checkable
class HasNoOps(Protocol):
    """Marker interface for models that set the `has_noops` flag."""

    has_noops: ClassVar[Literal[True]] = True


# The `type[object]` overload must come first; otherwise the broader
# `object` overload shadows it and classes are never narrowed correctly.
@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
    """Return the `has_noops` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "has_noops", False)
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """The interface required for all models that support cross encoding."""

    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return the raw `supports_cross_encoding` flag (False if unset)."""
    return getattr(model, "supports_cross_encoding", False)


# Overloads are kept contiguous with their implementation; the private helper
# above used to sit between them, which breaks the overload group for type
# checkers.
@overload
def supports_cross_encoding(
        model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Check whether `model` is a pooling model that supports cross encoding.

    Both `is_pooling_model(model)` and the `supports_cross_encoding` flag
    must be truthy.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)
def has_step_pooler(model: Union[type[object], object]) -> bool:
    """Check if the model uses step pooler.

    Returns True only when `model` is a pooling model whose `pooler`
    attribute is a `StepPooler`. A pooling model that exposes no `pooler`
    attribute now yields False instead of raising AttributeError.
    NOTE(review): this presumably covers model classes whose pooler is only
    assigned on instances; confirm.
    """
    # Imported lazily as in the original, presumably to avoid an import cycle
    # at module load time.
    from vllm.model_executor.layers.pooler import StepPooler

    if not is_pooling_model(model):
        return False
    return isinstance(getattr(model, "pooler", None), StepPooler)
class SupportsQuant:
    """The interface required for all models that support quantization."""

    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Create the model instance and attach any quantization config.

        If a quant config is found among the constructor arguments, it is
        stored on the instance. The instance's `hf_to_vllm_mapper` and
        `packed_modules_mapping` are then applied to that config when they
        are set.
        """
        instance = super().__new__(cls)

        # find config passed in arguments
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is not None:

            # attach config to model for general use
            instance.quant_config = quant_config

            # apply model mappings to config for proper config-model matching
            # NOTE: `TransformersForCausalLM` is not supported due to how this
            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
            if getattr(instance, "hf_to_vllm_mapper", None) is not None:
                instance.quant_config.apply_vllm_mapper(
                    instance.hf_to_vllm_mapper)
            if getattr(instance, "packed_modules_mapping", None) is not None:
                instance.quant_config.packed_modules_mapping.update(
                    instance.packed_modules_mapping)

        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Find quant config passed through model constructor args.

        Positional arguments are scanned first, then keyword values. The
        first `VllmConfig` found returns its `quant_config`, which may be
        None. The first `QuantizationConfig` found is returned as is.
        Returns None if neither is found.
        """
        from vllm.config import VllmConfig  # avoid circular import

        args_values = list(args) + list(kwargs.values())
        for arg in args_values:
            if isinstance(arg, VllmConfig):
                return arg.quant_config

            if isinstance(arg, QuantizationConfig):
                return arg

        return None
@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    supports_transcription: ClassVar[Literal[True]] = True

    supports_transcription_only: ClassVar[bool] = False
    """
    Transcription models can opt out of text generation by setting this to
    `True`.
    """

    @classmethod
    def get_generation_prompt(cls, audio: np.ndarray,
                              stt_config: SpeechToTextConfig,
                              model_config: ModelConfig, language: str,
                              task_type: str,
                              request_prompt: str) -> PromptType:
        """Get the prompt for the ASR model.
        The model has control over the construction, as long as it
        returns a valid PromptType."""
        ...

    @classmethod
    def validate_language(cls, language: str) -> bool:
        """Check if the model supports a specific ISO 639-1 language."""
        ...

    @classmethod
    def get_speech_to_text_config(
            cls, model_config: ModelConfig,
            task_type: Literal["transcribe",
                               "translate"]) -> SpeechToTextConfig:
        """Get the speech to text config for the ASR model."""
        ...

    @classmethod
    def get_num_audio_tokens(cls, audio_duration_s: float,
                             stt_config: SpeechToTextConfig,
                             model_config: ModelConfig) -> Optional[int]:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.
        The default implementation returns None.
        """
        return None

@overload
def supports_transcription(
        model: type[object]) -> TypeIs[type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Return the `supports_transcription` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "supports_transcription", False)


@runtime_checkable
class SupportsV0Only(Protocol):
    """Models with this interface are not compatible with V1 vLLM."""

    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Return the `supports_v0_only` flag of `model` (False if unset).

    Accepts either a model class or a model instance.
    """
    return getattr(model, "supports_v0_only", False)