interfaces.py 34.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Callable, Iterable, Mapping, MutableSequence
5
6
7
8
9
from typing import (
    TYPE_CHECKING,
    ClassVar,
    Literal,
    Protocol,
10
    TypeAlias,
11
12
13
    overload,
    runtime_checkable,
)
14

15
import numpy as np
16
import torch
17
import torch.nn as nn
18
from torch import Tensor
19
from transformers.models.whisper.tokenization_whisper import LANGUAGES
20
from typing_extensions import Self, TypeIs
21

22
from vllm.config import ModelConfig, SpeechToTextConfig
23
from vllm.inputs import TokensPrompt
24
from vllm.inputs.data import PromptType
25
from vllm.logger import init_logger
26
from vllm.model_executor.layers.quantization import QuantizationConfig
27
from vllm.utils.func_utils import supports_kw
28

29
from .interfaces_base import VllmModel, is_pooling_model
30

31
if TYPE_CHECKING:
32
    from vllm.config import VllmConfig
33
    from vllm.model_executor.models.utils import WeightsMapper
34
    from vllm.multimodal.inputs import MultiModalFeatureSpec
35
    from vllm.multimodal.registry import _ProcessorFactories
36
    from vllm.sequence import IntermediateTensors
37
38
39
else:
    VllmConfig = object
    WeightsMapper = object
40
    MultiModalFeatureSpec = object
41
    _ProcessorFactories = object
42
    IntermediateTensors = object
43

44
45
logger = init_logger(__name__)

46
MultiModalEmbeddings: TypeAlias = list[Tensor] | Tensor | tuple[Tensor, ...]
47
48
49
50
51
52
53
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g., image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
54

55

56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
def _require_is_multimodal(is_multimodal: Tensor | None) -> Tensor:
    """
    A helper function to be used in the context of
    [vllm.model_executor.models.interfaces.SupportsMultiModal.embed_input_ids][]
    to provide a better error message.
    """
    if is_multimodal is None:
        raise ValueError(
            "`embed_input_ids` now requires `is_multimodal` arg, "
            "please update your model runner according to "
            "https://github.com/vllm-project/vllm/pull/16229."
        )

    return is_multimodal


72
@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    supports_multimodal_raw_input_only: ClassVar[bool] = False
    """
    A flag that indicates this model supports multi-modal inputs and processes
    them in their raw form and not embeddings.
    """

    supports_encoder_tp_data: ClassVar[bool] = False
    """
    A flag that indicates whether this model supports
    `multimodal_config.mm_encoder_tp_mode="data"`.
    """

    requires_raw_input_tokens: ClassVar[bool] = False
    """
    A flag that indicates this model processes input id tokens
    in their raw form and not input embeddings.
    """

    _processor_factory: ClassVar[_ProcessorFactories]
    """
    Set internally by `MultiModalRegistry.register_processor`.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """
        Get the placeholder text for the `i`th `modality` item in the prompt.
        """
        ...

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> VllmModel:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    @classmethod
    def get_language_model_spec(cls) -> tuple[nn.Module | None, str | None]:
        """
        Return the language model spec:
        (language model class, language model attr)

        The default implementation returns `(None, None)`.
        """
        return None, None

    @overload
    def embed_input_ids(self, input_ids: Tensor) -> Tensor: ...

    @overload
    def embed_input_ids(
        self,
        input_ids: Tensor,
        multimodal_embeddings: MultiModalEmbeddings,
        *,
        is_multimodal: torch.Tensor,
        handle_oov_mm_token: bool = False,
    ) -> Tensor: ...

    def _embed_text_input_ids(
        self,
        input_ids: Tensor,
        embed_input_ids: Callable[[Tensor], Tensor],
        *,
        is_multimodal: Tensor | None,
        handle_oov_mm_token: bool,
    ) -> Tensor:
        """
        Embed `input_ids` with the given `embed_input_ids` callable.

        When `handle_oov_mm_token` is truthy and `is_multimodal` is given,
        only the positions where `is_multimodal` is False are passed to
        `embed_input_ids`; the results are scattered into a newly allocated
        buffer with one row per input token. Otherwise all of `input_ids`
        is embedded directly.
        """
        if handle_oov_mm_token and is_multimodal is not None:
            is_text = ~is_multimodal
            text_embeds = embed_input_ids(input_ids[is_text])

            # `torch.empty` does not initialize memory: rows at multimodal
            # positions hold arbitrary values until something writes them
            # (`embed_input_ids` does so via `_merge_multimodal_embeddings`).
            return torch.empty(
                (input_ids.shape[0], text_embeds.shape[1]),
                dtype=text_embeds.dtype,
                device=text_embeds.device,
            ).masked_scatter_(is_text.unsqueeze_(-1), text_embeds)

        return embed_input_ids(input_ids)

    def embed_input_ids(
        self,
        input_ids: Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> Tensor:
        """
        Apply token embeddings to `input_ids`.

        If `multimodal_embeddings` is passed, scatter them into
        `input_ids` according to the mask `is_multimodal`.

        In case the multi-modal token IDs exceed the vocabulary size of
        the language model, you can set `handle_oov_mm_token=True`
        to avoid calling the language model's `embed_input_ids` method
        on those tokens. Note however that doing so increases memory usage
        as an additional buffer is needed to hold the input embeddings.

        Raises:
            ValueError: If non-empty `multimodal_embeddings` are passed
                without `is_multimodal`.
        """
        from .utils import _merge_multimodal_embeddings

        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.get_language_model().embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # `len(...) == 0` rather than truthiness: `multimodal_embeddings`
        # may be a tensor, whose truth value is not a plain emptiness check.
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=_require_is_multimodal(is_multimodal),
        )
217

218

219
220
221
222
223
224
@runtime_checkable
class SupportsMultiModalPruning(Protocol):
    """The interface required for models that support returning both input
    embeddings and positions. A model may require custom positions for
    dynamic pruning of multimodal embeddings.
    """

    supports_multimodal_pruning: ClassVar[Literal[True]] = True

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: MultiModalEmbeddings,
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[MultiModalEmbeddings, Tensor, int]:
        """
        Update part of the input mrope positions (starting at index
        `num_computed_tokens`). The original `mrope_positions` are computed
        for the unpruned sequence and become incorrect once pruning occurs,
        so after media tokens are pruned this must be reflected in
        `mrope_positions` before they are fed to the LLM.

        Args:
            input_ids: (N,) All input tokens of the prompt containing
                entire sequence.
            multimodal_embeddings: Tuple of multimodal embeddings that
                fits into the prefill chunk that is being processed.
            mrope_positions: Existing mrope positions (3, N) for entire
                sequence.
            num_computed_tokens: A number of computed tokens so far.

        Returns:
            Tuple of (multimodal_embeddings, mrope_positions,
                mrope_position_delta).
        """
        ...


258
@overload
def supports_multimodal(model: type[object]) -> TypeIs[type[SupportsMultiModal]]: ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]: ...


def supports_multimodal(
    model: type[object] | object,
) -> TypeIs[type[SupportsMultiModal]] | TypeIs[SupportsMultiModal]:
    """Return the `supports_multimodal` attribute of `model` (class or
    instance), or `False` when the attribute is absent."""
    flag = getattr(model, "supports_multimodal", False)
    return flag
270
271


272
def supports_multimodal_raw_input_only(model: type[object] | object) -> bool:
273
    return getattr(model, "supports_multimodal_raw_input_only", False)
274

275

Patrick von Platen's avatar
Patrick von Platen committed
276
277
278
279
def requires_raw_input_tokens(model: type[object] | object) -> bool:
    return getattr(model, "requires_raw_input_tokens", False)


280
def supports_multimodal_encoder_tp_data(model: type[object] | object) -> bool:
281
    return getattr(model, "supports_encoder_tp_data", False)
282
283


284
285
286
287
def supports_mm_encoder_only(model: type[object] | object) -> bool:
    return getattr(model, "is_mm_encoder_only_model", False)


288
289
@overload
def supports_multimodal_pruning(
    model: type[object],
) -> TypeIs[type[SupportsMultiModalPruning]]: ...


@overload
def supports_multimodal_pruning(model: object) -> TypeIs[SupportsMultiModalPruning]: ...


def supports_multimodal_pruning(
    model: type[object] | object,
) -> TypeIs[type[SupportsMultiModalPruning]] | TypeIs[SupportsMultiModalPruning]:
    """Return the `supports_multimodal_pruning` attribute of `model` (class
    or instance), or `False` when the attribute is absent."""
    flag = getattr(model, "supports_multimodal_pruning", False)
    return flag


304
305
306
307
308
309
310
311
312
313
314
315
316
317
@runtime_checkable
class SupportsScoreTemplate(Protocol):
    """The interface required for all models that support score template."""

    supports_score_template: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports score template.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    @classmethod
    def get_score_template(cls, query: str, document: str) -> str | None:
        """
        Generate a full prompt by populating the score template with
        the query and document content.
        """
        ...

    @classmethod
    def post_process_tokens(cls, prompt: TokensPrompt) -> None:
        """
        Perform architecture-specific manipulations on the input tokens.
        """
        ...


@overload
def supports_score_template(
    model: type[object],
) -> TypeIs[type[SupportsScoreTemplate]]: ...


@overload
def supports_score_template(model: object) -> TypeIs[SupportsScoreTemplate]: ...


def supports_score_template(
    model: type[object] | object,
) -> TypeIs[type[SupportsScoreTemplate]] | TypeIs[SupportsScoreTemplate]:
    """Return the `supports_score_template` attribute of `model` (class or
    instance), or `False` when the attribute is absent."""
    flag = getattr(model, "supports_score_template", False)
    return flag
346
347


348
349
350
351
@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """
    is_3d_moe_weight: ClassVar[bool] = False
    # `embedding_modules` and `packed_modules_mapping`
    # are empty by default.
    embedding_modules: ClassVar[dict[str, str]] = {}
    packed_modules_mapping: dict[str, list[str]] = {}
365
366
367
368
369
370
371
372


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

373
374
    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
375
376
377


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]: ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]: ...


def supports_lora(
    model: type[object] | object,
) -> TypeIs[type[SupportsLoRA]] | TypeIs[SupportsLoRA]:
    """
    Return whether `model` (class or instance) satisfies the LoRA protocol.

    When the protocol check fails, a warning is logged if the model's
    `supports_lora` flag and the presence of the LoRA-specific attributes
    disagree with each other.
    """
    is_supported = _supports_lora(model)
    if is_supported:
        return is_supported

    required = ("packed_modules_mapping", "embedding_modules")
    absent = tuple(name for name in required if not hasattr(model, name))
    declares_support = getattr(model, "supports_lora", False)

    if declares_support and absent:
        logger.warning(
            "The model (%s) sets `supports_lora=True`, "
            "but is missing LoRA-specific attributes: %s",
            model,
            absent,
        )
    elif not declares_support and not absent:
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.",
            model,
        )

    return is_supported


416
def _supports_lora(model: type[object] | object) -> bool:
    """Check `model` against the LoRA protocol: classes are checked against
    `_SupportsLoRAType`, instances against `SupportsLoRA`."""
    protocol = _SupportsLoRAType if isinstance(model, type) else SupportsLoRA
    return isinstance(model, protocol)
421
422


423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: IntermediateTensors | None,
    ) -> IntermediateTensors | None:
        """
        Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
        PP rank > 0.

        Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only
        for the last PP rank.
        """
        ...


# `runtime_checkable` protocols with `ClassVar` members can't be used in
# `issubclass` checks, so the model class itself is treated as an instance
# and checked with `isinstance` against this protocol instead.
@runtime_checkable
class _SupportsPPType(Protocol):
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors: ...

    def forward(
        self,
        *,
        intermediate_tensors: IntermediateTensors | None,
    ) -> Tensor | IntermediateTensors: ...
478
479
480


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]: ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]: ...


def supports_pp(
    model: type[object] | object,
) -> bool | TypeIs[type[SupportsPP]] | TypeIs[SupportsPP]:
    """
    Return whether `model` (class or instance) supports pipeline parallel.

    Both checks must pass: the protocol check on the model's attributes,
    and the check that its `forward` accepts `intermediate_tensors`.
    Warnings are logged when the two checks, or the `supports_pp` flag and
    the PP-specific attributes, disagree with each other.
    """
    has_attributes = _supports_pp_attributes(model)
    forward_accepts = _supports_pp_inspect(model)

    if has_attributes:
        if not forward_accepts:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method",
                model,
            )
    else:
        required = ("make_empty_intermediate_tensors",)
        absent = tuple(name for name in required if not hasattr(model, name))
        declares_support = getattr(model, "supports_pp", False)

        if declares_support and absent:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                absent,
            )
        elif not declares_support and not absent:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.",
                model,
            )

    return has_attributes and forward_accepts


524
def _supports_pp_attributes(model: type[object] | object) -> bool:
    """Check `model` against the PP protocol: classes are checked against
    `_SupportsPPType`, instances against `SupportsPP`."""
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)


531
def _supports_pp_inspect(model: type[object] | object) -> bool:
    """Return whether `model.forward` is callable and, per `supports_kw`,
    accepts the `intermediate_tensors` keyword."""
    forward_fn = getattr(model, "forward", None)
    return callable(forward_fn) and supports_kw(forward_fn, "intermediate_tensors")
537
538


539
540
541
542
543
544
545
546
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
    """


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]: ...


@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]: ...


def has_inner_state(
    model: type[object] | object,
) -> TypeIs[type[HasInnerState]] | TypeIs[HasInnerState]:
    """Return the `has_inner_state` attribute of `model` (class or
    instance), or `False` when the attribute is absent."""
    flag = getattr(model, "has_inner_state", False)
    return flag
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578


@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant with respect to the number of
    tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has no attention.
    Used for block manager and attention backend selection.
    True for Mamba but not Jamba.
    """


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]: ...


@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]: ...


def is_attention_free(
    model: type[object] | object,
) -> TypeIs[type[IsAttentionFree]] | TypeIs[IsAttentionFree]:
    """Return the `is_attention_free` attribute of `model` (class or
    instance), or `False` when the attribute is absent."""
    flag = getattr(model, "is_attention_free", False)
    return flag
590
591


592
593
594
@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models like Jamba that have both
    attention and mamba blocks; it also indicates that the model's
    hf_config has 'layers_block_type'."""

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks,
    and that the model's hf_config has 'layers_block_type'.
    """

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: VllmConfig,
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        ...

621
622

@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]: ...


@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]: ...


def is_hybrid(
    model: type[object] | object,
) -> TypeIs[type[IsHybrid]] | TypeIs[IsHybrid]:
    """Return the `is_hybrid` attribute of `model` (class or instance),
    or `False` when the attribute is absent."""
    flag = getattr(model, "is_hybrid", False)
    return flag
634
635


636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Check if the model is a mixture of experts (MoE) model.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    Expert weights saved in this rank.

    The first dimension is the layer, and the second dimension is different
    parameters in the layer, e.g. up/down projection weights.
    """

    num_moe_layers: int
    """Number of MoE layers in this model."""

    num_expert_groups: int
    """Number of expert groups in this model."""

    num_logical_experts: int
    """Number of logical experts in this model."""

    num_physical_experts: int
    """Number of physical experts in this model."""

    num_local_physical_experts: int
    """Number of local physical experts in this model."""

    num_routed_experts: int
    """Number of routed experts in this model."""

    num_shared_experts: int
    """Number of shared experts in this model."""

    num_redundant_experts: int
    """Number of redundant experts in this model."""

    moe_layers: Iterable[nn.Module]
    """List of MoE layers in this model."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Register the EPLB state in the MoE model.

        Since these are views of the actual EPLB state, any changes made by
        the EPLB algorithm are automatically reflected in the model's behavior
        without requiring additional method calls to set new states.

        You should also collect model's `expert_weights` here instead of in
        the weight loader, since after initial weight loading, further
        processing like quantization may be applied to the weights.

        Args:
            expert_load_view: A view of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Count of replicas for each logical expert.
        """
        for moe_idx, moe_layer in enumerate(self.moe_layers):
            # Register the expert weights before handing the layer its
            # share of the EPLB state.
            layer_weights = moe_layer.get_expert_weights()
            self.expert_weights.append(layer_weights)
            moe_layer.set_eplb_state(
                moe_layer_idx=moe_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None: ...
714

715
716

def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return whether `model` satisfies the `MixtureOfExperts` protocol and
    reports a positive `num_moe_layers`."""
    if not isinstance(model, MixtureOfExperts):
        return False

    return getattr(model, "num_moe_layers", 0) > 0
720
721


722
723
724
725
726
727
@runtime_checkable
class HasNoOps(Protocol):
    """Marker interface for models that carry the `has_noops` flag."""

    has_noops: ClassVar[Literal[True]] = True


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]: ...


@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]: ...


def has_noops(
    model: type[object] | object,
) -> TypeIs[type[HasNoOps]] | TypeIs[HasNoOps]:
    """Return the `has_noops` attribute of `model` (class or instance),
    or `False` when the attribute is absent."""
    flag = getattr(model, "has_noops", False)
    return flag
739
740


741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
@runtime_checkable
class SupportsMambaPrefixCaching(Protocol):
    """The interface for models whose mamba layers support prefix caching.

    Note:
        This is currently experimental.
    """

    supports_mamba_prefix_caching: ClassVar[Literal[True]] = True


@overload
def supports_mamba_prefix_caching(
    model: object,
) -> TypeIs[SupportsMambaPrefixCaching]: ...


@overload
def supports_mamba_prefix_caching(
    model: type[object],
) -> TypeIs[type[SupportsMambaPrefixCaching]]: ...


def supports_mamba_prefix_caching(
    model: type[object] | object,
) -> TypeIs[type[SupportsMambaPrefixCaching]] | TypeIs[SupportsMambaPrefixCaching]:
    """Return the `supports_mamba_prefix_caching` attribute of `model` (class
    or instance), or `False` when the attribute is absent."""
    flag = getattr(model, "supports_mamba_prefix_caching", False)
    return flag


769
770
771
772
773
774
775
776
777
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """The interface required for all models that support cross encoding.

    Carries only the `supports_cross_encoding` flag.
    """

    supports_cross_encoding: ClassVar[Literal[True]] = True


@overload
def supports_cross_encoding(
    model: type[object],
) -> TypeIs[type[SupportsCrossEncoding]]: ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: ...


def _supports_cross_encoding(
    model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
    """Return the `supports_cross_encoding` attribute of `model`, or `False`
    when the attribute is absent."""
    flag = getattr(model, "supports_cross_encoding", False)
    return flag


def supports_cross_encoding(
    model: type[object] | object,
) -> TypeIs[type[SupportsCrossEncoding]] | TypeIs[SupportsCrossEncoding]:
    """Return whether `model` is a pooling model (per `is_pooling_model`)
    that also carries the `supports_cross_encoding` flag."""
    pooling = is_pooling_model(model)
    return pooling and _supports_cross_encoding(model)
796
797


798
799
800
class SupportsQuant:
    """The interface required for all models that support quantization."""

    hf_to_vllm_mapper: ClassVar[WeightsMapper | None] = None
    packed_modules_mapping: ClassVar[dict[str, list[str]] | None] = None
    quant_config: QuantizationConfig | None = None

    def __new__(cls, *args, **kwargs) -> Self:
        """
        Create the instance and, if a quantization config can be found among
        the constructor arguments, attach it and apply the model's mappings
        to it.
        """
        instance = super().__new__(cls)

        # Look for a config among the constructor arguments.
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is None:
            return instance

        # Attach the config to the model for general use.
        instance.quant_config = quant_config

        # Apply model mappings to the config for proper config-model matching.
        mapper = instance.hf_to_vllm_mapper
        if mapper is not None:
            quant_config.apply_vllm_mapper(mapper)

        packed = instance.packed_modules_mapping
        if packed is not None:
            quant_config.packed_modules_mapping.update(packed)

        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> QuantizationConfig | None:
        """Find quant config passed through model constructor args.

        Positional arguments are scanned first, then keyword values; the
        first `VllmConfig` (its `quant_config`) or `QuantizationConfig`
        encountered wins. Returns `None` if neither is present.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, VllmConfig):
                return candidate.quant_config

            if isinstance(candidate, QuantizationConfig):
                return candidate

        return None


840
841
842
@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    # Mapping from ISO639_1 language codes: language names
    supported_languages: ClassVar[Mapping[str, str]]

    supports_transcription: ClassVar[Literal[True]] = True

    supports_transcription_only: ClassVar[bool] = False
    """
    Transcription models can opt out of text generation by setting this to
    `True`.
    """
    supports_segment_timestamp: ClassVar[bool] = False
    """
    Enables the segment timestamp option for supported models by setting
    this to `True`.
    """

    def __init_subclass__(cls, **kwargs):
        """
        Validate `supported_languages` when a subclass is created.

        Raises:
            ValueError: If `supported_languages` contains codes that are
                not keys of the full `LANGUAGES` map.
        """
        super().__init_subclass__(**kwargs)
        # language codes in supported_languages
        # that don't exist in the full language map
        invalid = set(cls.supported_languages).difference(LANGUAGES)
        if invalid:
            raise ValueError(
                f"{cls.__name__}.supported_languages contains invalid "
                f"language codes: {sorted(invalid)}.\n"
                f"Valid choices are: {sorted(LANGUAGES.keys())}"
            )

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Get the prompt for the ASR model.
        The model has control over the construction, as long as it
        returns a valid PromptType."""
        ...

    @classmethod
    def get_other_languages(cls) -> Mapping[str, str]:
        """Return the entries of the whisper `LANGUAGES` map whose codes are
        not in `supported_languages`."""
        return {k: v for k, v in LANGUAGES.items() if k not in cls.supported_languages}

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """
        Ensure the language specified in the transcription request
        is a valid ISO 639-1 language code. If the request language is
        valid, but not natively supported by the model, trigger a
        warning (but not an exception).

        Returns:
            `language` unchanged (including `None`).

        Raises:
            ValueError: If `language` is neither in `supported_languages`
                nor in `get_other_languages()`.
        """
        if language is None or language in cls.supported_languages:
            return language

        if language in cls.get_other_languages():
            logger.warning(
                "Language %r is not natively supported by %s; "
                "results may be less accurate. Supported languages: %r",
                language,
                cls.__name__,
                list(cls.supported_languages.keys()),
            )
            return language

        raise ValueError(
            f"Unsupported language: {language!r}.  Must be one of "
            f"{list(cls.supported_languages.keys())}."
        )

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: Literal["transcribe", "translate"]
    ) -> SpeechToTextConfig:
        """Get the speech to text config for the ASR model."""
        ...

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """
        Map from audio duration to number of audio tokens produced by the ASR
        model, without running a forward pass.
        This is used for estimating the amount of processing for this audio.

        The default implementation returns `None`.
        """
        return None

938
939
940

@overload
def supports_transcription(
    model: type[object],
) -> TypeIs[type[SupportsTranscription]]: ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]: ...


def supports_transcription(
    model: type[object] | object,
) -> TypeIs[type[SupportsTranscription]] | TypeIs[SupportsTranscription]:
    """
    Check whether a model class or instance advertises transcription support.

    The check reads the `supports_transcription` attribute of `model`;
    objects that do not define the attribute are reported as `False`.
    """
    return getattr(model, "supports_transcription", False)
953
954


955
@runtime_checkable
class SupportsEagleBase(Protocol):
    """Base interface for models that support EAGLE-based speculative decoding."""

    has_own_lm_head: bool = False
    """
    A flag that indicates this model has trained its own lm_head.
    """

    has_own_embed_tokens: bool = False
    """
    A flag that indicates this model has trained its own input embeddings.
    """


@overload
def supports_any_eagle(
    model: type[object],
) -> TypeIs[type[SupportsEagleBase]]: ...


@overload
def supports_any_eagle(
    model: object,
) -> TypeIs[SupportsEagleBase]: ...


def supports_any_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagleBase]] | TypeIs[SupportsEagleBase]:
    """
    Tell whether `model` supports at least one EAGLE variant.

    The EAGLE-1/EAGLE-2 check runs first; the EAGLE-3 check is only
    consulted when the first one fails.
    """
    if supports_eagle(model):
        return True
    return supports_eagle3(model)


@runtime_checkable
class SupportsEagle(SupportsEagleBase, Protocol):
    """
    Interface for models that can be used with EAGLE-1 and EAGLE-2
    speculative decoding.
    """

    supports_eagle: ClassVar[Literal[True]] = True
    """
    Marker flag: this model supports EAGLE-1 and EAGLE-2
    speculative decoding.

    Note:
        Model classes that have this class in their MRO inherit the flag
        and do not need to redefine it.
    """


@overload
def supports_eagle(
    model: type[object],
) -> TypeIs[type[SupportsEagle]]: ...


@overload
def supports_eagle(
    model: object,
) -> TypeIs[SupportsEagle]: ...


def supports_eagle(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagle]] | TypeIs[SupportsEagle]:
    """Tell whether `model` passes the runtime `SupportsEagle` protocol check."""
    conforms = isinstance(model, SupportsEagle)
    return conforms


@runtime_checkable
class SupportsEagle3(SupportsEagleBase, Protocol):
    """The interface required for models that support
    EAGLE-3 speculative decoding."""

    supports_eagle3: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports EAGLE-3
    speculative decoding.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """
        Set which layers should output auxiliary
        hidden states for EAGLE-3.

        Args:
            layers: Tuple of layer indices that should output auxiliary
                hidden states.
        """
        ...

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """
        Get the layer indices that should output auxiliary hidden states
        for EAGLE-3.

        Returns:
            Tuple of layer indices for auxiliary hidden state outputs.
        """
        ...


@overload
def supports_eagle3(model: type[object]) -> TypeIs[type[SupportsEagle3]]: ...


@overload
def supports_eagle3(model: object) -> TypeIs[SupportsEagle3]: ...


def supports_eagle3(
    model: type[object] | object,
) -> TypeIs[type[SupportsEagle3]] | TypeIs[SupportsEagle3]:
    """Check whether `model` passes the runtime `SupportsEagle3` protocol check."""
    return isinstance(model, SupportsEagle3)
1064
1065
1066
1067
1068
1069
1070
1071
1072


@runtime_checkable
class SupportsMRoPE(Protocol):
    """The interface required for all models that support M-RoPE."""

    supports_mrope: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports M-RoPE.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list["MultiModalFeatureSpec"],
    ) -> tuple[torch.Tensor, int]:
        """
        Get M-RoPE input positions and delta value for this specific model.

        This method should be implemented by each model that supports M-RoPE
        to provide model-specific logic for computing input positions.

        Args:
            input_tokens: List of input token IDs
            mm_features: Information about each multi-modal data item

        Returns:
            Tuple of `(llm_positions, mrope_position_delta)`
            - llm_positions: Tensor of shape `[3, num_tokens]` with T/H/W positions
            - mrope_position_delta: Delta for position calculations
        """
        ...


@overload
def supports_mrope(model: type[object]) -> TypeIs[type[SupportsMRoPE]]: ...


@overload
def supports_mrope(model: object) -> TypeIs[SupportsMRoPE]: ...


def supports_mrope(
    model: type[object] | object,
) -> TypeIs[type[SupportsMRoPE]] | TypeIs[SupportsMRoPE]:
    """Check whether `model` passes the runtime `SupportsMRoPE` protocol check."""
    return isinstance(model, SupportsMRoPE)
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162


@runtime_checkable
class SupportsXDRoPE(Protocol):
    """The interface required for all models that support XD-RoPE."""

    supports_xdrope: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports XD-RoPE.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_xdrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list["MultiModalFeatureSpec"],
    ) -> torch.Tensor:
        """
        Get XD-RoPE input positions for this specific model.

        This method should be implemented by each model that supports XD-RoPE
        to provide model-specific logic for computing input positions.

        Args:
            input_tokens: List of input token IDs
            mm_features: Information about each multi-modal data item

        Returns:
            llm_positions: Tensor of shape `[xdrope_dim, num_tokens]` with
            4D(P/W/H/T) or 3D(W/H/T) positions.
        """
        ...


@overload
def supports_xdrope(
    model: type[object],
) -> TypeIs[type[SupportsXDRoPE]]: ...


@overload
def supports_xdrope(
    model: object,
) -> TypeIs[SupportsXDRoPE]: ...


def supports_xdrope(
    model: type[object] | object,
) -> TypeIs[type[SupportsXDRoPE]] | TypeIs[SupportsXDRoPE]:
    """Tell whether `model` passes the runtime `SupportsXDRoPE` protocol check."""
    conforms = isinstance(model, SupportsXDRoPE)
    return conforms