interfaces.py 19.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)

import torch
from torch import Tensor
from typing_extensions import Self, TypeIs

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.utils import supports_kw

from .interfaces_base import is_pooling_model

if TYPE_CHECKING:
    from vllm.attention import AttentionMetadata
    from vllm.model_executor.models.utils import WeightsMapper
    from vllm.sequence import IntermediateTensors

24
25
logger = init_logger(__name__)

MultiModalEmbeddings = Union[list[torch.Tensor], torch.Tensor,
                             tuple[torch.Tensor, ...]]
"""
Accepted formats for the output multimodal embeddings:

- A list or tuple of 2D tensors, one tensor per input multimodal data item
  (e.g., an image).
- A single 3D tensor whose batch dimension stacks those 2D tensors.
"""

@runtime_checkable
class SupportsMultiModal(Protocol):
    """Protocol that every model accepting multi-modal inputs must follow."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    Marker flag declaring that the model accepts multi-modal inputs.

    Note:
        A model class that already has this protocol in its MRO inherits the
        flag and does not need to redefine it.
    """

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """
        Return the prompt placeholder text for the `i`th item of the given
        `modality`.
        """
        ...

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """
        Compute the multimodal embeddings described by `kwargs`; they are
        later merged with the text embeddings.

        Note:
            The embeddings must be returned in the same order in which their
            multimodal data items appear in the input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Return the language model backbone used for text generation.

        This is usually the `torch.nn.Module` that consumes the merged
        multimodal embeddings and produces hidden states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    # TODO: Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        ...

    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        # Present only so that the v0 overload above is valid.
        # TODO: Remove attn_metadata once v0 is deprecated
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        """
        Return the input embeddings produced by merging the text embeddings
        of `input_ids` with the multimodal embeddings computed from the
        multimodal kwargs.
        """
        ...

# Protocols that declare data members (such as a `ClassVar` flag) cannot be
# passed to `issubclass`, so a class object is instead checked with
# `isinstance` against this variant, whose member is a plain annotation.
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    supports_multimodal: Literal[True]


@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether `model` (a class or an instance) is multi-modal.

    Class objects are matched against `_SupportsMultiModalType` with
    `isinstance`, because protocols with data members cannot be used with
    `issubclass`.
    """
    protocol = (_SupportsMultiModalType
                if isinstance(model, type) else SupportsMultiModal)
    return isinstance(model, protocol)


@runtime_checkable
class SupportsLoRA(Protocol):
    """Protocol that every LoRA-capable model must follow."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    Marker flag declaring that the model supports LoRA.

    Note:
        A model class that already has this protocol in its MRO inherits the
        flag and does not need to redefine it.
    """

    # LoRA-related module mappings; each one defaults to an empty container.
    embedding_modules: ClassVar[dict[str, str]] = {}
    embedding_padding_modules: ClassVar[list[str]] = []
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}


# Protocols that declare data members (such as a `ClassVar` flag) cannot be
# passed to `issubclass`, so a class object is instead checked with
# `isinstance` against this variant, whose members are plain annotations.
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
    embedding_padding_modules: list[str]


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether `model` (a class or an instance) supports LoRA.

    When the check fails, a warning is logged in two cases:
    - the model sets `supports_lora=True` but lacks some LoRA attributes;
    - the model has every LoRA attribute but does not set the flag.
    """
    is_lora = _supports_lora(model)
    if is_lora:
        return is_lora

    required_attrs = (
        "packed_modules_mapping",
        "embedding_modules",
        "embedding_padding_modules",
    )
    missing_attrs = tuple(name for name in required_attrs
                          if not hasattr(model, name))
    declares_flag = getattr(model, "supports_lora", False)

    if declares_flag and missing_attrs:
        logger.warning(
            "The model (%s) sets `supports_lora=True`, "
            "but is missing LoRA-specific attributes: %s",
            model,
            missing_attrs,
        )
    elif not declares_flag and not missing_attrs:
        logger.warning(
            "The model (%s) contains all LoRA-specific attributes, "
            "but does not set `supports_lora=True`.", model)

    return is_lora


def _supports_lora(model: Union[type[object], object]) -> bool:
    """Structural LoRA check that accepts either a class or an instance."""
    if not isinstance(model, type):
        return isinstance(model, SupportsLoRA)
    # Class objects are checked as *instances* of the non-ClassVar variant.
    return isinstance(model, _SupportsLoRAType)


@runtime_checkable
class SupportsPP(Protocol):
    """Protocol that every model supporting pipeline parallelism must follow."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    Marker flag declaring that the model supports pipeline parallelism.

    Note:
        A model class that already has this protocol in its MRO inherits the
        flag and does not need to redefine it.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Invoked for profiling purposes when the PP rank is above 0."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Take [`IntermediateTensors`][vllm.sequence.IntermediateTensors] as
        input when the PP rank is above 0.

        Produce [`IntermediateTensors`][vllm.sequence.IntermediateTensors]
        as output only on the last PP rank.
        """
        ...


# Protocols that declare data members (such as a `ClassVar` flag) cannot be
# passed to `issubclass`, so a class object is instead checked with
# `isinstance` against this variant, whose flag is a plain annotation.
@runtime_checkable
class _SupportsPPType(Protocol):
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether `model` (a class or an instance) supports PP.

    Both conditions must hold:
    - the model structurally matches the PP protocol;
    - its `forward` accepts an `intermediate_tensors` keyword.

    Warnings are logged for half-configured models.
    """
    has_attributes = _supports_pp_attributes(model)
    accepts_kwarg = _supports_pp_inspect(model)

    if has_attributes and not accepts_kwarg:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not has_attributes:
        required_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(name for name in required_attrs
                              if not hasattr(model, name))
        declares_flag = getattr(model, "supports_pp", False)

        if declares_flag and missing_attrs:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                missing_attrs,
            )
        elif not declares_flag and not missing_attrs:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_attributes and accepts_kwarg


def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
    """Structural PP check (flag plus required methods) for a class or an
    instance; the signature of `forward` is not inspected here."""
    protocol = _SupportsPPType if isinstance(model, type) else SupportsPP
    return isinstance(model, protocol)


def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
    """Return whether `model.forward` accepts `intermediate_tensors`.

    Returns False when `model` has no callable `forward` attribute.
    """
    forward_fn = getattr(model, "forward", None)
    return callable(forward_fn) and supports_kw(forward_fn,
                                                "intermediate_tensors")


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
        Marker flag declaring that this model has inner state.
        Such models usually need access to the scheduler_config (for
        max_num_seqs, etc.). True for, e.g., both Mamba and Jamba.
    """


# Class-object variant of `HasInnerState`, checked with `isinstance` because
# protocols with data members cannot be used with `issubclass`. From the
# point of view of the class object being checked, the flag is an ordinary
# attribute, so it is annotated without `ClassVar` -- the same convention as
# `_SupportsMultiModalType`, `_SupportsLoRAType` and `_SupportsPPType`.
@runtime_checkable
class _HasInnerStateType(Protocol):
    has_inner_state: Literal[True]


# The `type[object]` overload must come first: `type[object]` is a subtype of
# `object`, so if the `object` overload were listed first the class overload
# could never be selected by a type checker.
@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether `model` (a class or an instance) has inner state."""
    if isinstance(model, type):
        # Class objects are checked as instances of the class-level variant.
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)


@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models, like Mamba, that have no
    attention but keep a state whose size is constant with respect to the
    number of tokens."""

    is_attention_free: ClassVar[Literal[True]] = True
    """
        Marker flag declaring that this model has no attention.
        Used when selecting the block manager and attention backend.
        True for Mamba but not for Jamba.
    """


# Class-object variant of `IsAttentionFree`, checked with `isinstance` because
# protocols with data members cannot be used with `issubclass`. From the
# point of view of the class object being checked, the flag is an ordinary
# attribute, so it is annotated without `ClassVar` -- the same convention as
# `_SupportsMultiModalType`, `_SupportsLoRAType` and `_SupportsPPType`.
@runtime_checkable
class _IsAttentionFreeType(Protocol):
    is_attention_free: Literal[True]


# The `type[object]` overload must come first: `type[object]` is a subtype of
# `object`, so if the `object` overload were listed first the class overload
# could never be selected by a type checker.
@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether `model` (a class or an instance) is attention-free."""
    if isinstance(model, type):
        # Class objects are checked as instances of the class-level variant.
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)


@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models, like Jamba, that have both
    attention and mamba blocks; it also indicates that the model's
    hf_config has 'layers_block_type'."""

    is_hybrid: ClassVar[Literal[True]] = True
    """
        Marker flag declaring that this model has both mamba and attention
        blocks, and that the model's hf_config has 'layers_block_type'.
    """


# Class-object variant of `IsHybrid`, checked with `isinstance` because
# protocols with data members cannot be used with `issubclass`. From the
# point of view of the class object being checked, the flag is an ordinary
# attribute, so it is annotated without `ClassVar` -- the same convention as
# `_SupportsMultiModalType`, `_SupportsLoRAType` and `_SupportsPPType`.
@runtime_checkable
class _IsHybridType(Protocol):
    is_hybrid: Literal[True]


# The `type[object]` overload must come first: `type[object]` is a subtype of
# `object`, so if the `object` overload were listed first the class overload
# could never be selected by a type checker.
@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
    """Return whether `model` (a class or an instance) is a hybrid model."""
    if isinstance(model, type):
        # Class objects are checked as instances of the class-level variant.
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Structural interface of a mixture-of-experts (MoE) model; check a model
    against it with `isinstance`.
    """

    expert_weights: MutableSequence[Iterable[Tensor]]
    """
    Expert weights held by this rank.

    Indexed first by layer, then by the individual parameters within that
    layer (e.g. the up/down projection weights).
    """

    num_moe_layers: int
    """Number of MoE layers in this model."""

    num_expert_groups: int
    """Number of expert groups in this model."""

    num_logical_experts: int
    """Number of logical experts in this model."""

    num_physical_experts: int
    """Number of physical experts in this model."""

    num_local_physical_experts: int
    """Number of local physical experts in this model."""

    num_routed_experts: int
    """Number of routed experts in this model."""

    num_shared_experts: int
    """Number of shared experts in this model."""

    num_redundant_experts: int
    """Number of redundant experts in this model."""

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Register the EPLB state with the MoE model.

        The arguments are views onto the actual EPLB state. Updates made by
        the EPLB algorithm therefore take effect in the model automatically,
        with no further method calls.

        Collect the model's `expert_weights` here rather than in the weight
        loader. After the initial load the weights may still be processed
        further (e.g. quantized).

        Args:
            expert_load_view: View of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Number of replicas of each logical expert.
        """
        ...


def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return whether `model` structurally matches `MixtureOfExperts`."""
    matches_protocol = isinstance(model, MixtureOfExperts)
    return matches_protocol


@runtime_checkable
class HasNoOps(Protocol):
    """Marker protocol for models that set `has_noops = True`."""

    has_noops: ClassVar[Literal[True]] = True


# Class-object variant of `HasNoOps`, checked with `isinstance` because
# protocols with data members cannot be used with `issubclass`. From the
# point of view of the class object being checked, the flag is an ordinary
# attribute, so it is annotated without `ClassVar` -- the same convention as
# `_SupportsMultiModalType`, `_SupportsLoRAType` and `_SupportsPPType`.
@runtime_checkable
class _HasNoOpsType(Protocol):
    has_noops: Literal[True]


# The `type[object]` overload must come first: `type[object]` is a subtype of
# `object`, so if the `object` overload were listed first the class overload
# could never be selected by a type checker.
@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
    """Return whether `model` (a class or an instance) sets `has_noops`."""
    if isinstance(model, type):
        # Class objects are checked as instances of the class-level variant.
        return isinstance(model, _HasNoOpsType)

    return isinstance(model, HasNoOps)


@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Protocol that every model supporting cross encoding must follow."""

    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(model: Union[type[object], object]) -> bool:
    """Return whether `model` (a class or an instance) structurally matches
    `SupportsCrossEncoding`.

    The protocol's only member is a class-level flag, which is visible on a
    class and on its instances alike, so one `isinstance` check covers both.
    """
    return isinstance(model, SupportsCrossEncoding)


# NOTE: The `@overload` stubs below must be immediately followed by the
# implementation, so the private helper above is defined first.
@overload
def supports_cross_encoding(
        model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether `model` is a pooling model that supports cross
    encoding."""
    return is_pooling_model(model) and _supports_cross_encoding(model)


def has_step_pooler(model: Union[type[object], object]) -> bool:
    """Return whether `model` is a pooling model that contains a submodule
    whose class is named `StepPool`."""

    def _is_step_pool(module: object) -> bool:
        # Matched by class name rather than by type.
        return type(module).__name__ == "StepPool"

    # NOTE(review): `model.modules()` is only called on pooling models; if a
    # class object (not an instance) reaches this call, it will fail.
    return is_pooling_model(model) and any(
        map(_is_step_pool, model.modules()))


class SupportsQuant:
    """The interface required for all models that support quantization."""

    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        instance = super().__new__(cls)

        # Look for a quantization config among the constructor arguments.
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is None:
            return instance

        # Expose the config on the model for general use.
        instance.quant_config = quant_config

        # Apply the model's mappings to the config so that the config and
        # the model's modules match up.
        # NOTE: `TransformersForCausalLM` is not supported because that class
        # defines `hf_to_vllm_mapper` as a post-init `@property`. Once that
        # is fixed, read `instance.hf_to_vllm_mapper` directly.
        if getattr(instance, "hf_to_vllm_mapper", None) is not None:
            instance.quant_config.apply_vllm_mapper(
                instance.hf_to_vllm_mapper)
        if getattr(instance, "packed_modules_mapping", None) is not None:
            instance.quant_config.packed_modules_mapping.update(
                instance.packed_modules_mapping)

        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Return the quant config passed through the model constructor.

        Scans positional arguments first, then keyword values, in order.
        - The first `VllmConfig` found yields its `quant_config`, which may
          be None.
        - The first `QuantizationConfig` found is returned as-is.
        Returns None when neither appears.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, VllmConfig):
                return candidate.quant_config
            if isinstance(candidate, QuantizationConfig):
                return candidate
        return None


@runtime_checkable
class SupportsTranscription(Protocol):
    """Protocol that every model supporting transcription must follow."""

    supports_transcription: ClassVar[Literal[True]] = True

    @classmethod
    def get_decoder_prompt(cls, language: str, task_type: str,
                           prompt: str) -> str:
        """Build the decoder prompt for the ASR model."""
        ...

    @classmethod
    def validate_language(cls, language: str) -> bool:
        """Return whether the model supports the given ISO639_1 language."""
        ...

@overload
def supports_transcription(
        model: type[object]) -> TypeIs[type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Return whether `model` (a class or an instance) supports
    transcription.

    A separate branch for class objects is unnecessary. The protocol's
    members (a class-level flag and two classmethods) are found on a class
    and on its instances alike, and a runtime-checkable `isinstance` check
    only tests that each member is present.
    """
    return isinstance(model, SupportsTranscription)


@runtime_checkable
class SupportsV0Only(Protocol):
    """Marker for models that are incompatible with V1 vLLM."""

    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Return whether `model` (a class or an instance) is V0-only.

    The protocol's only member is a class-level flag, which is visible on a
    class and on its instances alike, so a single `isinstance` check covers
    both. The previous class/instance branches were identical.
    """
    return isinstance(model, SupportsV0Only)