interfaces.py 19.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from collections.abc import Iterable, MutableSequence
5
6
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)
7

8
import torch
9
from torch import Tensor
10
from typing_extensions import Self, TypeIs
11
12

from vllm.logger import init_logger
13
14
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
15
from vllm.utils import supports_kw
16

17
from .interfaces_base import is_pooling_model
18

19
if TYPE_CHECKING:
20
    from vllm.attention import AttentionMetadata
21
    from vllm.model_executor.models.utils import WeightsMapper
22
23
    from vllm.sequence import IntermediateTensors

24
25
# Module-level logger; the capability checks in this module use it to warn
# when a model's declared flags and its attributes disagree.
logger = init_logger(__name__)


MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g., image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""


@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Structural check applied to model *classes* via isinstance()."""

    supports_multimodal: Literal[True]


@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Check whether a model class or instance implements
    `SupportsMultiModal`.

    Args:
        model: A model class or a model instance.

    Returns:
        True if `model` structurally matches the protocol. For classes only
        the `supports_multimodal` flag is checked; for instances every
        protocol member must be present.
    """
    # Protocols with data members don't support issubclass(), so a class
    # object is checked with isinstance() against the non-ClassVar variant.
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # The `embedding_modules` and `embedding_padding_modules`
    # are empty by default.
    # NOTE: these mutable defaults are shared by every class that inherits
    # them without overriding; override them rather than mutating in place.
    embedding_modules: ClassVar[dict[str, str]] = {}
    embedding_padding_modules: ClassVar[list[str]] = []
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Structural check applied to model *classes* via isinstance()."""

    supports_lora: Literal[True]

    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
    embedding_padding_modules: list[str]


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Check whether a model class or instance implements `SupportsLoRA`.

    When the structural check fails, a warning is logged if the model's
    `supports_lora` flag and its LoRA-specific attributes disagree: the flag
    is set but attributes are missing, or every attribute is present but the
    flag is not set.

    Args:
        model: A model class or a model instance.

    Returns:
        The result of the structural check.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        elif not missing_attrs:
            logger.warning(
                "The model (%s) contains all LoRA-specific attributes, "
                "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(model: Union[type[object], object]) -> bool:
    """Structural LoRA check without any logging.

    Class objects are matched with isinstance() against `_SupportsLoRAType`,
    because protocols with data members don't support issubclass().
    Instances are matched against `SupportsLoRA`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)


@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
        PP rank > 0.

        Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] on
        every PP rank except the last one; the last PP rank returns a
        `Tensor` instead.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Structural check applied to model *classes* via isinstance()."""

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
    """Check whether a model class or instance supports pipeline parallel.

    Two conditions must both hold:
      * the model structurally matches `SupportsPP`, and
      * its `forward` accepts an `intermediate_tensors` keyword argument.

    A warning is logged in either of these cases:
      * the flag is set but `forward` lacks that keyword;
      * the structural check fails and the `supports_pp` flag disagrees with
        the presence of `make_empty_intermediate_tensors`.

    Args:
        model: A model class or a model instance.

    Returns:
        True only if both conditions hold.
    """
    supports_attributes = _supports_pp_attributes(model)
    supports_inspect = _supports_pp_inspect(model)

    if supports_attributes and not supports_inspect:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not supports_attributes:
        pp_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(attr for attr in pp_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_pp", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_pp=True`, "
                    "but is missing PP-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        elif not missing_attrs:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return supports_attributes and supports_inspect


def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
    """Structural half of the pipeline-parallel check.

    Class objects are matched with isinstance() against `_SupportsPPType`,
    because protocols with data members don't support issubclass().
    Instances are matched against `SupportsPP`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
    """Signature half of the pipeline-parallel check.

    Returns True when `model.forward` exists, is callable, and
    `supports_kw` reports that it accepts an `intermediate_tensors`
    keyword argument.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the
    scheduler_config for max_num_seqs, etc. True for e.g. both Mamba and
    Jamba.
    """


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead.
# The member is deliberately not a ClassVar: a ClassVar protocol member can
# never be matched by a class object during static checking.
@runtime_checkable
class _HasInnerStateType(Protocol):
    has_inner_state: Literal[True]


# The `type[object]` overload must come first: `object` also matches class
# objects, so listing it first would make the class overload unreachable.
@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
    """Check whether a model class or instance implements `HasInnerState`.

    Class objects are matched with isinstance() against `_HasInnerStateType`;
    instances are matched against `HasInnerState`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)


@runtime_checkable
class IsAttentionFree(Protocol):
    """Interface for models without attention (for example Mamba).

    Such models still keep state, but its size stays constant with respect
    to the number of tokens.
    """

    # Marks the model as having no attention. Used for block manager and
    # attention backend selection. True for Mamba, but not for Jamba.
    is_attention_free: ClassVar[Literal[True]] = True


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead.
# The member is deliberately not a ClassVar: a ClassVar protocol member can
# never be matched by a class object during static checking.
@runtime_checkable
class _IsAttentionFreeType(Protocol):
    is_attention_free: Literal[True]


# The `type[object]` overload must come first: `object` also matches class
# objects, so listing it first would make the class overload unreachable.
@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Check whether a model class or instance implements `IsAttentionFree`.

    Class objects are matched with isinstance() against
    `_IsAttentionFreeType`; instances are matched against `IsAttentionFree`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)


@runtime_checkable
class IsHybrid(Protocol):
    """Interface for models such as Jamba that combine attention blocks with
    mamba blocks.

    Implementing it also signals that the model's hf_config has a
    'layers_block_type' entry.
    """

    # Marks the model as a mamba/attention hybrid; it also implies that the
    # model's hf_config has 'layers_block_type'.
    is_hybrid: ClassVar[Literal[True]] = True


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead.
# The member is deliberately not a ClassVar: a ClassVar protocol member can
# never be matched by a class object during static checking.
@runtime_checkable
class _IsHybridType(Protocol):
    is_hybrid: Literal[True]


# The `type[object]` overload must come first: `object` also matches class
# objects, so listing it first would make the class overload unreachable.
@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
    """Check whether a model class or instance implements `IsHybrid`.

    Class objects are matched with isinstance() against `_IsHybridType`;
    instances are matched against `IsHybrid`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


@runtime_checkable
class MixtureOfExperts(Protocol):
    """
    Structural interface identifying a mixture of experts (MoE) model.
    """

    # Expert weights saved in this rank. The outer sequence is indexed by
    # layer; the inner iterable holds that layer's parameters (e.g. the
    # up/down projection weights).
    expert_weights: MutableSequence[Iterable[Tensor]]

    # Number of MoE layers in this model.
    num_moe_layers: int
    # Number of expert groups in this model.
    num_expert_groups: int
    # Number of logical experts in this model.
    num_logical_experts: int
    # Number of physical experts in this model.
    num_physical_experts: int
    # Number of local physical experts in this model.
    num_local_physical_experts: int
    # Number of routed experts in this model.
    num_routed_experts: int
    # Number of shared experts in this model.
    num_shared_experts: int
    # Number of redundant experts in this model.
    num_redundant_experts: int

    def set_eplb_state(
        self,
        expert_load_view: Tensor,
        logical_to_physical_map: Tensor,
        logical_replica_count: Tensor,
    ) -> None:
        """
        Hand the EPLB state to the MoE model.

        The arguments are views of the actual EPLB state. Changes the EPLB
        algorithm makes to that state therefore show up in the model's
        behavior without any further method calls.

        Collect the model's `expert_weights` here rather than in the weight
        loader: after the initial load, further processing such as
        quantization may still be applied to the weights.

        Args:
            expert_load_view: A view of the expert load metrics tensor.
            logical_to_physical_map: Mapping from logical to physical experts.
            logical_replica_count: Count of replicas for each logical expert.
        """
        ...


def is_mixture_of_experts(model: object) -> TypeIs[MixtureOfExperts]:
    """Return True when `model` structurally satisfies `MixtureOfExperts`."""
    matches_protocol = isinstance(model, MixtureOfExperts)
    return matches_protocol


@runtime_checkable
class HasNoOps(Protocol):
    """Interface for models that declare the `has_noops` flag."""

    # Marker flag. Classes that inherit this protocol get it set to True
    # without redefining it.
    has_noops: ClassVar[Literal[True]] = True


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead.
# The member is deliberately not a ClassVar: a ClassVar protocol member can
# never be matched by a class object during static checking.
@runtime_checkable
class _HasNoOpsType(Protocol):
    has_noops: Literal[True]


# The `type[object]` overload must come first: `object` also matches class
# objects, so listing it first would make the class overload unreachable.
@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
    """Check whether a model class or instance implements `HasNoOps`.

    Class objects are matched with isinstance() against `_HasNoOpsType`;
    instances are matched against `HasNoOps`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasNoOpsType)

    return isinstance(model, HasNoOps)


@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Interface that every model supporting cross encoding implements."""

    # Marker flag. Classes that inherit this protocol get it set to True
    # without redefining it.
    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Structural check for the `supports_cross_encoding` flag.

    One isinstance() call covers both classes and instances: the protocol
    only declares a data member, and the runtime check looks it up on
    whatever object it is given.
    """
    return isinstance(model, SupportsCrossEncoding)


# The overloads are kept directly above their implementation, so type
# checkers associate them with it.
@overload
def supports_cross_encoding(
        model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Check whether a model is a pooling model that supports cross encoding.

    Args:
        model: A model class or a model instance.

    Returns:
        True only if `is_pooling_model(model)` holds and `model` carries the
        `supports_cross_encoding` flag.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)


def has_step_pooler(model: Union[type[object], object]) -> bool:
    """Check if the model uses step pooler.

    Returns a truthy result only for pooling models that contain at least
    one submodule whose class is named ``StepPool``. The scan calls
    ``model.modules()``, so it presumably expects a model instance.
    NOTE(review): confirm the behavior when a class is passed.
    """
    pooling = is_pooling_model(model)
    if not pooling:
        # Short-circuit exactly like `a and b`: return the falsy value as-is.
        return pooling

    return any(
        type(submodule).__name__ == "StepPool"
        for submodule in model.modules())


class SupportsQuant:
    """The interface required for all models that support quantization."""

    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Create the instance and attach any quantization config found in
        the constructor arguments.

        This runs before `__init__`, so `quant_config` is already set when
        the subclass constructor executes.
        """
        instance = super().__new__(cls)

        # find config passed in arguments
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is not None:
            # attach config to model for general use
            instance.quant_config = quant_config

            # apply model mappings to config for proper config-model matching
            # NOTE: `TransformersForCausalLM` is not supported due to how this
            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
            if getattr(instance, "hf_to_vllm_mapper", None) is not None:
                instance.quant_config.apply_vllm_mapper(
                    instance.hf_to_vllm_mapper)
            # NOTE: `update` mutates the config object in place, so every
            # holder of this config sees the merged mapping.
            if getattr(instance, "packed_modules_mapping", None) is not None:
                instance.quant_config.packed_modules_mapping.update(
                    instance.packed_modules_mapping)

        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Find quant config passed through model constructor args.

        Positional arguments are scanned first, then keyword values. The
        first match wins:
          * a `VllmConfig` yields its `quant_config`, even when that is None;
          * a `QuantizationConfig` is returned as-is.

        Returns None when neither kind of argument is present.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for arg in (*args, *kwargs.values()):
            if isinstance(arg, VllmConfig):
                return arg.quant_config

            if isinstance(arg, QuantizationConfig):
                return arg

        return None


@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    supports_transcription: ClassVar[Literal[True]] = True

    @classmethod
    def get_decoder_prompt(cls, language: str, task_type: str,
                           prompt: str) -> str:
        """Get the decoder prompt for the ASR model."""
        ...

    @classmethod
    def validate_language(cls, language: str) -> bool:
        """Check if the model supports a specific ISO639_1 language."""
        ...


@overload
def supports_transcription(
        model: type[object]) -> TypeIs[type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Check whether a model class or instance implements
    `SupportsTranscription`.

    Classes and instances go through the same isinstance() call. The former
    class/instance branches were identical, so they are collapsed into one.
    """
    return isinstance(model, SupportsTranscription)


@runtime_checkable
class SupportsV0Only(Protocol):
    """Marks models that are not compatible with V1 vLLM."""

    # Marker flag. Classes that inherit this protocol get it set to True
    # without redefining it.
    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Check whether a model class or instance is marked `SupportsV0Only`.

    Classes and instances go through the same isinstance() call. The former
    class/instance branches were identical, so they are collapsed into one.
    """
    return isinstance(model, SupportsV0Only)