interfaces.py 15.7 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
                    Union, overload, runtime_checkable)
5

6
import torch
7
from torch import Tensor
8
from typing_extensions import Self, TypeIs
9
10

from vllm.logger import init_logger
11
12
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
13
from vllm.utils import supports_kw
14

15
from .interfaces_base import is_pooling_model
16

17
if TYPE_CHECKING:
18
    from vllm.attention import AttentionMetadata
19
20
    from vllm.sequence import IntermediateTensors

21
22
# Module-level logger; used below to warn about inconsistent LoRA / pipeline
# parallel interface declarations on models.
logger = init_logger(__name__)

23
24
25
26
27
28
29
30
MultiModalEmbeddings = Union[list[Tensor], Tensor, tuple[Tensor, ...]]
"""
The output embeddings must be one of the following formats:

- A list or tuple of 2D tensors, where each tensor corresponds to
    each input multimodal data item (e.g, image).
- A single 3D tensor, with the batch dimension grouping the 2D tensors.
"""
31

32
33

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    def get_language_model(self) -> torch.nn.Module:
        """
        Returns the underlying language model used for text generation.

        This is typically the `torch.nn.Module` instance responsible for
        processing the merged multimodal embeddings and producing hidden
        states.

        Returns:
            torch.nn.Module: The core language model component.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...

95
96
97
98

# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Flag-only protocol used to check model *classes* via ``isinstance``."""

    supports_multimodal: Literal[True]


@overload
def supports_multimodal(
        model: type[object]) -> TypeIs[type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Check whether a model class or instance supports multi-modal inputs.

    A model class is checked only for the ``supports_multimodal`` flag
    (via ``_SupportsMultiModalType``). An instance is checked against the
    full ``SupportsMultiModal`` protocol, including its methods.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # `embedding_modules`, `embedding_padding_modules` and
    # `packed_modules_mapping` are empty by default. These defaults are
    # single class-level objects shared by every subclass that does not
    # override them, so subclasses should rebind rather than mutate them.
    embedding_modules: ClassVar[dict[str, str]] = {}
    embedding_padding_modules: ClassVar[list[str]] = []
    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    """Attribute-only protocol used to check model *classes* for LoRA."""

    supports_lora: Literal[True]

    packed_modules_mapping: dict[str, list[str]]
    embedding_modules: dict[str, str]
    embedding_padding_modules: list[str]


@overload
def supports_lora(model: type[object]) -> TypeIs[type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Check whether a model class or instance supports LoRA.

    When the structural check fails, a warning is logged if the model's
    declaration looks inconsistent: it sets ``supports_lora`` to a truthy
    value but lacks some LoRA-specific attributes, or it has all of them
    but does not set the flag.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(model: Union[type[object], object]) -> bool:
    """Structural LoRA check without any logging.

    Model classes are matched against the attribute-only
    ``_SupportsLoRAType`` protocol; instances against ``SupportsLoRA``.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)


@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    # NOTE(review): model implementations commonly return IntermediateTensors
    # from every rank *except* the last (which returns hidden states); confirm
    # that the "only for the last PP rank" wording below is intended.
    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        """
        Accept [`IntermediateTensors`][vllm.sequence.IntermediateTensors] when
        PP rank > 0.

        Return [`IntermediateTensors`][vllm.sequence.IntermediateTensors] only
        for the last PP rank.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Protocol used to check model *classes* for pipeline-parallel support."""

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: type[object]) -> TypeIs[type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[type[object], object],
) -> Union[bool, TypeIs[type[SupportsPP]], TypeIs[SupportsPP]]:
    """Check whether a model class or instance supports pipeline parallelism.

    Two conditions must both hold: the model matches the PP protocol
    attributes, and its ``forward`` accepts an ``intermediate_tensors``
    keyword argument. Inconsistent declarations are reported as warnings.
    """
    supports_attributes = _supports_pp_attributes(model)
    supports_inspect = _supports_pp_inspect(model)

    if supports_attributes and not supports_inspect:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, but does not accept "
            "`intermediate_tensors` in its `forward` method", model)

    if not supports_attributes:
        pp_attrs = ("make_empty_intermediate_tensors", )
        missing_attrs = tuple(attr for attr in pp_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_pp", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_pp=True`, "
                    "but is missing PP-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all PP-specific attributes, "
                    "but does not set `supports_pp=True`.", model)

    return supports_attributes and supports_inspect


def _supports_pp_attributes(model: Union[type[object], object]) -> bool:
    """Structural PP check: classes via ``_SupportsPPType``, instances via
    ``SupportsPP``. Does not inspect the ``forward`` signature."""
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


def _supports_pp_inspect(model: Union[type[object], object]) -> bool:
    """Return True if ``model.forward`` exists, is callable, and accepts an
    ``intermediate_tensors`` keyword argument (as judged by ``supports_kw``).
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
    """


@runtime_checkable
class _HasInnerStateType(Protocol):
    """
    Flag-only counterpart of ``HasInnerState``; lets ``isinstance`` be
    applied to model classes as well as instances.
    """

    has_inner_state: ClassVar[Literal[True]]


# The `type[object]` overload must come before the `object` one: every class
# is also an `object`, so listing `object` first would make the class overload
# unreachable for type checkers.
@overload
def has_inner_state(model: type[object]) -> TypeIs[type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasInnerState]], TypeIs[HasInnerState]]:
    """Check whether a model class or instance declares ``has_inner_state``.

    Classes are checked via ``_HasInnerStateType``; instances via
    ``HasInnerState``.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)


@runtime_checkable
class IsAttentionFree(Protocol):
    """
    Interface for models, like Mamba, that have no attention layers but do
    keep a state whose size is constant with respect to the number of tokens.
    """

    is_attention_free: ClassVar[Literal[True]] = True
    """
    Marks the model as having no attention. Used for block manager and
    attention backend selection. True for Mamba but not Jamba.
    """


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    """
    Flag-only counterpart of ``IsAttentionFree``; lets ``isinstance`` be
    applied to model classes as well as instances.
    """

    is_attention_free: ClassVar[Literal[True]]


# The `type[object]` overload must come before the `object` one: every class
# is also an `object`, so listing `object` first would make the class overload
# unreachable for type checkers.
@overload
def is_attention_free(model: type[object]) -> TypeIs[type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Check whether a model class or instance declares
    ``is_attention_free``.

    Classes are checked via ``_IsAttentionFreeType``; instances via
    ``IsAttentionFree``.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)


@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models like Jamba that have both
    attention and mamba blocks; also indicates that the model's hf_config
    has 'layers_block_type'."""

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks,
    and that the model's hf_config has 'layers_block_type'.
    """


@runtime_checkable
class _IsHybridType(Protocol):
    """
    Flag-only counterpart of ``IsHybrid``; lets ``isinstance`` be applied to
    model classes as well as instances.
    """

    is_hybrid: ClassVar[Literal[True]]


# The `type[object]` overload must come before the `object` one: every class
# is also an `object`, so listing `object` first would make the class overload
# unreachable for type checkers.
@overload
def is_hybrid(model: type[object]) -> TypeIs[type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[type[object], object]
) -> Union[TypeIs[type[IsHybrid]], TypeIs[IsHybrid]]:
    """Check whether a model class or instance declares ``is_hybrid``.

    Classes are checked via ``_IsHybridType``; instances via ``IsHybrid``.
    """
    if isinstance(model, type):
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


@runtime_checkable
class HasNoOps(Protocol):
    """Marker interface for models that set the ``has_noops`` flag."""

    # NOTE(review): the meaning of this flag is defined by whatever consumes
    # `has_noops()`, which is not visible here; confirm and document it.
    has_noops: ClassVar[Literal[True]] = True


@runtime_checkable
class _HasNoOpsType(Protocol):
    """
    Flag-only counterpart of ``HasNoOps``; lets ``isinstance`` be applied to
    model classes as well as instances.
    """

    has_noops: ClassVar[Literal[True]]


# The `type[object]` overload must come before the `object` one: every class
# is also an `object`, so listing `object` first would make the class overload
# unreachable for type checkers.
@overload
def has_noops(model: type[object]) -> TypeIs[type[HasNoOps]]:
    ...


@overload
def has_noops(model: object) -> TypeIs[HasNoOps]:
    ...


def has_noops(
    model: Union[type[object], object]
) -> Union[TypeIs[type[HasNoOps]], TypeIs[HasNoOps]]:
    """Check whether a model class or instance declares ``has_noops``.

    Classes are checked via ``_HasNoOpsType``; instances via ``HasNoOps``.
    """
    if isinstance(model, type):
        return isinstance(model, _HasNoOpsType)

    return isinstance(model, HasNoOps)


@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """The interface required for all models that support cross encoding."""

    # A plain data member (no methods), so the same `isinstance` check can be
    # applied to both model classes and model instances.
    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Check the ``supports_cross_encoding`` flag only, for a model class or
    instance. This does not check that the model is a pooling model."""
    # `SupportsCrossEncoding` declares only a data member, so one `isinstance`
    # check covers both model classes and instances.
    return isinstance(model, SupportsCrossEncoding)


# The overloads are placed directly before their implementation so that type
# checkers group them together; a different `def` in between breaks the group.
@overload
def supports_cross_encoding(
        model: type[object]) -> TypeIs[type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Check whether a model class or instance supports cross encoding.

    True only when ``is_pooling_model`` accepts the model and it also
    declares ``supports_cross_encoding``.
    """
    return is_pooling_model(model) and _supports_cross_encoding(model)


class SupportsQuant:
    """The interface required for all models that support quantization."""

    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> Self:
        """Create the instance and attach a quantization config, if one is
        found among the constructor arguments.

        When a config is found, this class's ``packed_modules_mapping`` is
        merged into the config's own ``packed_modules_mapping``. The config
        object is mutated in place.
        """
        # Constructor arguments are deliberately not forwarded: they are
        # consumed by ``__init__``, and ``object.__new__`` rejects extra
        # arguments when ``__new__`` is overridden.
        instance = super().__new__(cls)
        quant_config = cls._find_quant_config(*args, **kwargs)
        if quant_config is not None:
            instance.quant_config = quant_config
            instance.quant_config.packed_modules_mapping.update(
                cls.packed_modules_mapping)
        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Return the first quantization config found in the arguments.

        Positional arguments are scanned before keyword arguments. The first
        ``VllmConfig`` ends the search and yields its ``quant_config``, which
        may be None. A bare ``QuantizationConfig`` is returned as-is.
        Returns None if neither is present.
        """
        from vllm.config import VllmConfig  # avoid circular import

        args_values = list(args) + list(kwargs.values())
        for arg in args_values:
            if isinstance(arg, VllmConfig):
                return arg.quant_config

            if isinstance(arg, QuantizationConfig):
                return arg

        return None


@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    # A plain data member (no methods), so the same `isinstance` check can be
    # applied to both model classes and model instances.
    supports_transcription: ClassVar[Literal[True]] = True


@overload
def supports_transcription(
        model: type[object]) -> TypeIs[type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Check whether a model class or instance supports transcription."""
    # `SupportsTranscription` declares only a data member, so one `isinstance`
    # check covers both model classes and instances.
    return isinstance(model, SupportsTranscription)


@runtime_checkable
class SupportsV0Only(Protocol):
    """
    Marker interface: models implementing it are incompatible with V1 vLLM.
    """

    supports_v0_only: ClassVar[Literal[True]] = True


@overload
def supports_v0_only(model: type[object]) -> TypeIs[type[SupportsV0Only]]:
    ...


@overload
def supports_v0_only(model: object) -> TypeIs[SupportsV0Only]:
    ...


def supports_v0_only(
    model: Union[type[object], object],
) -> Union[TypeIs[type[SupportsV0Only]], TypeIs[SupportsV0Only]]:
    """Check whether a model class or instance is marked as V0-only."""
    # `SupportsV0Only` declares only a data member, so one `isinstance` check
    # covers both model classes and instances.
    return isinstance(model, SupportsV0Only)