interfaces.py 14.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                    Protocol, Type, Union, overload, runtime_checkable)
5

6
import torch
7
from typing_extensions import TypeIs, TypeVar
8
9

from vllm.logger import init_logger
10
11
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
12
from vllm.utils import supports_kw
13

14
from .interfaces_base import is_pooling_model
15

16
if TYPE_CHECKING:
17
18
    from vllm.attention import AttentionMetadata
    from vllm.multimodal.inputs import NestedTensors  # noqa: F401
19
    from vllm.sequence import IntermediateTensors
20

21
22
23
24
if TYPE_CHECKING:
    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
    from vllm.sequence import IntermediateTensors

25
26
logger = init_logger(__name__)

# Element type of multimodal embeddings. The default is the string
# "NestedTensors" because that name is only imported under TYPE_CHECKING.
T = TypeVar("T", default="NestedTensors")

29
30

@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models."""

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def get_multimodal_embeddings(self, **kwargs) -> Optional[T]:
        """
        Returns multimodal embeddings generated from multimodal kwargs
        to be merged with text embeddings.

        The output embeddings must be one of the following formats:

        - A list or tuple of 2D tensors, where each tensor corresponds to
          each input multimodal data item (e.g. image).
        - A single 3D tensor, with the batch dimension grouping the 2D tensors.

        The return type is ``Optional``, so ``None`` may also be returned
        (presumably when there are no multimodal inputs -- confirm against
        implementations).

        Note:
            The returned multimodal embeddings must be in the same order as
            the appearances of their corresponding multimodal data item in the
            input prompt.
        """
        ...

    # Only for models that support v0 chunked prefill
    # TODO(ywang96): Remove this overload once v0 is deprecated
    @overload
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
        attn_metadata: Optional["AttentionMetadata"] = None,
    ) -> torch.Tensor:
        ...

    @overload
    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[T] = None,
    ) -> torch.Tensor:
        """
        Returns the input embeddings merged from the text embeddings from
        input_ids and the multimodal embeddings generated from multimodal
        kwargs.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsMultiModalType(Protocol):
    """Class-level counterpart of :class:`SupportsMultiModal`.

    Used as ``isinstance(model_cls, _SupportsMultiModalType)``. The runtime
    check only tests that the attribute exists on the class. The
    ``Literal[True]`` value is enforced by static type checkers only.
    """

    supports_multimodal: Literal[True]
91
92
93


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether ``model`` (a class or an instance) is multi-modal.

    Classes are matched against :class:`_SupportsMultiModalType`, because
    the ``ClassVar``-based protocol cannot be used for ``issubclass``
    checks. Instances are matched against :class:`SupportsMultiModal`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)
111
112
113
114
115
116


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA."""

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # These members have no default here. A model must define each of them
    # itself, or runtime isinstance checks against this protocol fail.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    # Class-level counterpart of SupportsLoRA. A model *class* is passed in,
    # as in isinstance(model_cls, _SupportsLoRAType) in _supports_lora().
    # The runtime check only tests that each member below exists on the
    # class; the Literal[True] value is not verified at runtime.
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether ``model`` (a class or an instance) supports LoRA.

    When the check fails, a warning is logged in either of two inconsistent
    states:

    - ``supports_lora`` is set but some LoRA attributes are missing.
    - Every LoRA attribute is present but ``supports_lora`` is not set.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        elif not missing_attrs:
            logger.warning(
                "The model (%s) contains all LoRA-specific attributes, "
                "but does not set `supports_lora=True`.", model)

    return result


186
def _supports_lora(model: Union[Type[object], object]) -> bool:
    """Raw protocol check used by :func:`supports_lora` (no logging).

    Classes are matched against :class:`_SupportsLoRAType`. Instances are
    matched against :class:`SupportsLoRA`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)
191
192


193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
@runtime_checkable
class SupportsPP(Protocol):
    """Protocol implemented by every model that can run with pipeline
    parallelism (PP)."""

    supports_pp: ClassVar[Literal[True]] = True
    """
    Marker flag: this model supports pipeline parallel.

    Note:
        Having this class in your model's MRO is enough; the flag does not
        need to be redefined.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Create empty intermediate tensors. Invoked when the PP rank is
        greater than 0, for profiling purposes."""

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        On PP ranks greater than 0, takes :class:`IntermediateTensors` as
        input.

        :class:`IntermediateTensors` is returned only for the last PP rank.
        """
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    """Class-level counterpart of :class:`SupportsPP`.

    Used as ``isinstance(model_cls, _SupportsPPType)``. The runtime check
    only tests that every member below exists on the class.
    """

    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Check whether ``model`` (class or instance) supports pipeline parallel.

    Two conditions must both hold:

    - The protocol members are present.
    - ``forward`` accepts an ``intermediate_tensors`` keyword argument.

    Inconsistent declarations are reported through ``logger.warning``.
    """
    has_attrs = _supports_pp_attributes(model)
    accepts_kw = _supports_pp_inspect(model)

    if has_attrs:
        if not accepts_kw:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method", model)
    else:
        missing = tuple(name
                        for name in ("make_empty_intermediate_tensors", )
                        if not hasattr(model, name))
        flagged = getattr(model, "supports_pp", False)

        if flagged and missing:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, "
                "but is missing PP-specific attributes: %s",
                model,
                missing,
            )
        elif not flagged and not missing:
            logger.warning(
                "The model (%s) contains all PP-specific attributes, "
                "but does not set `supports_pp=True`.", model)

    return has_attrs and accepts_kw


293
def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Return whether ``model`` exposes the :class:`SupportsPP` members.

    Classes are matched against :class:`_SupportsPPType`. Instances are
    matched against :class:`SupportsPP`. The ``forward`` signature is not
    inspected here.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


300
def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts ``intermediate_tensors``.

    The keyword check is delegated to ``supports_kw``. Returns ``False``
    when ``model`` has no callable ``forward`` attribute.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")
306
307


308
309
310
311
312
313
314
315
@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state."""

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
    """


@runtime_checkable
class _HasInnerStateType(Protocol):
    # Class-level counterpart of HasInnerState, used by has_inner_state() as
    # isinstance(model_cls, _HasInnerStateType). The runtime check only tests
    # that the attribute exists, not that its value is True.
    has_inner_state: ClassVar[Literal[True]]


# The ``Type[object]`` overload must come first. A class object also
# matches ``object``, so listing ``object`` first would make the class
# overload unreachable for type checkers. This order matches
# supports_multimodal / supports_lora / supports_pp.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether ``model`` (a class or an instance) has inner state.

    Classes are matched against :class:`_HasInnerStateType`. Instances are
    matched against :class:`HasInnerState`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378


@runtime_checkable
class IsAttentionFree(Protocol):
    """The interface required for all models like Mamba that lack attention,
    but do have state whose size is constant wrt the number of tokens."""

    # Read by is_attention_free(). Inheriting this class supplies the True
    # default, so subclasses need not redefine the flag.
    is_attention_free: ClassVar[Literal[True]] = True
    """
        A flag that indicates this model has no attention.
        Used for block manager and attention backend selection.
        True for Mamba but not Jamba.
    """


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    # Class-level counterpart of IsAttentionFree, used by is_attention_free()
    # as isinstance(model_cls, _IsAttentionFreeType). The runtime check only
    # tests that the attribute exists.
    is_attention_free: ClassVar[Literal[True]]


# The ``Type[object]`` overload must come first. A class object also
# matches ``object``, so the reverse order makes it unreachable for type
# checkers.
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` (a class or an instance) is attention-free.

    Classes are matched against :class:`_IsAttentionFreeType`. Instances are
    matched against :class:`IsAttentionFree`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)
379
380


381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
@runtime_checkable
class IsHybrid(Protocol):
    """The interface required for all models like Jamba that have both
    attention and mamba blocks. It also indicates that the model's
    hf_config has 'layers_block_type'."""

    is_hybrid: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has both mamba and attention blocks;
    it also indicates that the model's hf_config has 'layers_block_type'.
    """


@runtime_checkable
class _IsHybridType(Protocol):
    # Class-level counterpart of IsHybrid, used by is_hybrid() as
    # isinstance(model_cls, _IsHybridType). The runtime check only tests
    # that the attribute exists.
    is_hybrid: ClassVar[Literal[True]]


# The ``Type[object]`` overload must come first. A class object also
# matches ``object``, so the reverse order makes it unreachable for type
# checkers.
@overload
def is_hybrid(model: Type[object]) -> TypeIs[Type[IsHybrid]]:
    ...


@overload
def is_hybrid(model: object) -> TypeIs[IsHybrid]:
    ...


def is_hybrid(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsHybrid]], TypeIs[IsHybrid]]:
    """Return whether ``model`` (a class or an instance) is a hybrid model.

    Classes are matched against :class:`_IsHybridType`. Instances are
    matched against :class:`IsHybrid`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsHybridType)

    return isinstance(model, IsHybrid)


418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """The interface required for all models that support cross encoding."""

    # Read by supports_cross_encoding(), which additionally requires the
    # model to pass is_pooling_model().
    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(model: Union[Type[object], object]) -> bool:
    """Return whether ``model`` (class or instance) carries the
    :class:`SupportsCrossEncoding` flag.

    Classes and instances are checked the same way. The protocol's only
    member is the ``supports_cross_encoding`` flag, and ``isinstance`` on
    a runtime protocol tests for that attribute's presence.
    """
    return isinstance(model, SupportsCrossEncoding)


# The overloads are kept directly in front of their implementation. A
# definition in between breaks the overload sequence for type checkers.
@overload
def supports_cross_encoding(
        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` is a pooling model that also supports
    cross encoding."""
    return is_pooling_model(model) and _supports_cross_encoding(model)
450
451


452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
class SupportsQuant:
    """Mixin for models that support quantization.

    On construction it searches the constructor arguments for a
    quantization config. If one is found, it is stored on the new instance
    and this class's ``packed_modules_mapping`` is merged into it.
    """

    packed_modules_mapping: ClassVar[Dict[str, List[str]]] = {}
    quant_config: Optional[QuantizationConfig] = None

    def __new__(cls, *args, **kwargs) -> "SupportsQuant":
        instance = super().__new__(cls)
        found = cls._find_quant_config(*args, **kwargs)
        if found is None:
            # Nothing found: the class-level ``quant_config`` stays in effect.
            return instance

        instance.quant_config = found
        # Register this model's packed-module layout with the config.
        instance.quant_config.packed_modules_mapping.update(
            cls.packed_modules_mapping)
        return instance

    @staticmethod
    def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
        """Return the first quantization config found in the arguments.

        Positional arguments are scanned before keyword values.

        - A ``VllmConfig`` yields its ``quant_config`` attribute, which may
          be ``None``.
        - A ``QuantizationConfig`` is returned as-is.
        """
        from vllm.config import VllmConfig  # avoid circular import

        for candidate in (*args, *kwargs.values()):
            if isinstance(candidate, VllmConfig):
                return candidate.quant_config
            if isinstance(candidate, QuantizationConfig):
                return candidate
        return None


482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
@runtime_checkable
class SupportsTranscription(Protocol):
    """The interface required for all models that support transcription."""

    # Read by supports_transcription(). Inheriting this class supplies the
    # True default.
    supports_transcription: ClassVar[Literal[True]] = True


@overload
def supports_transcription(
        model: Type[object]) -> TypeIs[Type[SupportsTranscription]]:
    ...


@overload
def supports_transcription(model: object) -> TypeIs[SupportsTranscription]:
    ...


def supports_transcription(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsTranscription]], TypeIs[SupportsTranscription]]:
    """Return whether ``model`` (a class or an instance) supports
    transcription.

    Classes and instances go through the same ``isinstance`` check. The
    protocol's only member is the ``supports_transcription`` flag, and a
    runtime protocol check tests for that attribute's presence.
    """
    return isinstance(model, SupportsTranscription)