"""Protocols that model classes implement to advertise optional capabilities
(multi-modal inputs, LoRA, pipeline parallelism, inner state, attention-free
architecture, cross encoding), plus helpers that check for them."""
from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional,
                    Protocol, Type, Union, overload, runtime_checkable)

import torch
from typing_extensions import TypeIs

from vllm.logger import init_logger
from vllm.utils import supports_kw

from .interfaces_base import is_embedding_model

if TYPE_CHECKING:
    # Only needed for annotations; importing at runtime is avoided.
    from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig
    from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


@runtime_checkable
class SupportsMultiModal(Protocol):
    """The interface required for all multi-modal models.

    Use :func:`supports_multimodal` to check a model class or instance
    against this interface.
    """

    supports_multimodal: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports multi-modal inputs.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def __init__(self, *, multimodal_config: "MultiModalConfig") -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
39
40
class _SupportsMultiModalType(Protocol):
    supports_multimodal: Literal[True]
41

42
    def __call__(self, *, multimodal_config: "MultiModalConfig") -> None:
43
44
45
46
        ...


@overload
def supports_multimodal(
        model: Type[object]) -> TypeIs[Type[SupportsMultiModal]]:
    ...


@overload
def supports_multimodal(model: object) -> TypeIs[SupportsMultiModal]:
    ...


def supports_multimodal(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsMultiModal]], TypeIs[SupportsMultiModal]]:
    """Return whether ``model`` (a model class or instance) implements
    :class:`SupportsMultiModal`.

    A class is checked against :class:`_SupportsMultiModalType`, which treats
    the class object itself as the instance being checked. An instance is
    checked directly against :class:`SupportsMultiModal`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsMultiModalType)

    return isinstance(model, SupportsMultiModal)


@runtime_checkable
class SupportsLoRA(Protocol):
    """The interface required for all models that support LoRA.

    Use :func:`supports_lora` to check a model class or instance against
    this interface.
    """

    supports_lora: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports LoRA.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    # LoRA-specific metadata. Unlike the flag above, the protocol gives these
    # no default, so each implementing model must define them itself.
    packed_modules_mapping: ClassVar[Dict[str, List[str]]]
    supported_lora_modules: ClassVar[List[str]]
    embedding_modules: ClassVar[Dict[str, str]]
    embedding_padding_modules: ClassVar[List[str]]

    # lora_config is None when LoRA is not enabled
    def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsLoRAType(Protocol):
    supports_lora: Literal[True]

    packed_modules_mapping: Dict[str, List[str]]
    supported_lora_modules: List[str]
    embedding_modules: Dict[str, str]
    embedding_padding_modules: List[str]

100
    def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None:
101
102
103
104
        ...


@overload
def supports_lora(model: Type[object]) -> TypeIs[Type[SupportsLoRA]]:
    ...


@overload
def supports_lora(model: object) -> TypeIs[SupportsLoRA]:
    ...


def supports_lora(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsLoRA]], TypeIs[SupportsLoRA]]:
    """Return whether ``model`` (a model class or instance) supports LoRA.

    When the check fails, a warning is logged if the model looks only
    partly configured:

    - it sets ``supports_lora=True`` but is missing one of the LoRA-specific
      attributes, or
    - it has all of those attributes but does not set ``supports_lora=True``.
    """
    result = _supports_lora(model)

    if not result:
        lora_attrs = (
            "packed_modules_mapping",
            "supported_lora_modules",
            "embedding_modules",
            "embedding_padding_modules",
        )
        missing_attrs = tuple(attr for attr in lora_attrs
                              if not hasattr(model, attr))

        if getattr(model, "supports_lora", False):
            if missing_attrs:
                logger.warning(
                    "The model (%s) sets `supports_lora=True`, "
                    "but is missing LoRA-specific attributes: %s",
                    model,
                    missing_attrs,
                )
        else:
            if not missing_attrs:
                logger.warning(
                    "The model (%s) contains all LoRA-specific attributes, "
                    "but does not set `supports_lora=True`.", model)

    return result


def _supports_lora(model: Union[Type[object], object]) -> bool:
    """Structural LoRA check with no logging (see :func:`supports_lora`).

    A class is matched against :class:`_SupportsLoRAType`; an instance is
    matched against :class:`SupportsLoRA`.
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsLoRAType)

    return isinstance(model, SupportsLoRA)


@runtime_checkable
class SupportsPP(Protocol):
    """The interface required for all models that support pipeline parallel.

    Use :func:`supports_pp` to check a model class or instance against this
    interface.
    """

    supports_pp: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model supports pipeline parallel.

    Note:
        There is no need to redefine this flag if this class is in the
        MRO of your model class.
    """

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        """Called when PP rank > 0 for profiling purposes."""
        ...

    def forward(
        self,
        *,
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        """
        Accept :class:`IntermediateTensors` when PP rank > 0.

        Return :class:`IntermediateTensors` only for the last PP rank.
        """
        ...


# We can't use runtime_checkable with ClassVar for issubclass checks
# so we need to treat the class as an instance and use isinstance instead
@runtime_checkable
class _SupportsPPType(Protocol):
    supports_pp: Literal[True]

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> "IntermediateTensors":
        ...

    def forward(
        self,
204
        *,
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
        intermediate_tensors: Optional["IntermediateTensors"],
    ) -> Union[torch.Tensor, "IntermediateTensors"]:
        ...


@overload
def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]:
    ...


@overload
def supports_pp(model: object) -> TypeIs[SupportsPP]:
    ...


def supports_pp(
    model: Union[Type[object], object],
) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]:
    """Return whether ``model`` (a model class or instance) supports
    pipeline parallelism.

    Two conditions must both hold: the model structurally matches
    :class:`SupportsPP`, and its ``forward`` accepts
    ``intermediate_tensors``. Warnings are logged for inconsistent setups:

    - the structural check passes but ``forward`` rejects
      ``intermediate_tensors``;
    - ``supports_pp=True`` is set but PP-specific attributes are missing;
    - all PP-specific attributes exist but ``supports_pp=True`` is not set.
    """
    has_pp_attrs = _supports_pp_attributes(model)
    accepts_intermediate = _supports_pp_inspect(model)

    if has_pp_attrs:
        # Structural match: the answer hinges only on forward's signature.
        if not accepts_intermediate:
            logger.warning(
                "The model (%s) sets `supports_pp=True`, but does not accept "
                "`intermediate_tensors` in its `forward` method", model)
        return accepts_intermediate

    absent = tuple(
        name for name in ("make_empty_intermediate_tensors", )
        if not hasattr(model, name))
    flagged = getattr(model, "supports_pp", False)

    if flagged and absent:
        logger.warning(
            "The model (%s) sets `supports_pp=True`, "
            "but is missing PP-specific attributes: %s",
            model,
            absent,
        )
    elif not flagged and not absent:
        logger.warning(
            "The model (%s) contains all PP-specific attributes, "
            "but does not set `supports_pp=True`.", model)

    return has_pp_attrs


def _supports_pp_attributes(model: Union[Type[object], object]) -> bool:
    """Structural pipeline-parallel check with no logging.

    A class is matched against :class:`_SupportsPPType`; an instance is
    matched against :class:`SupportsPP`. The signature of ``forward`` is not
    inspected here (see ``_supports_pp_inspect``).
    """
    if isinstance(model, type):
        return isinstance(model, _SupportsPPType)

    return isinstance(model, SupportsPP)


def _supports_pp_inspect(model: Union[Type[object], object]) -> bool:
    """Return whether ``model.forward`` accepts ``intermediate_tensors``.

    The decision is delegated to ``supports_kw``. ``False`` is returned when
    ``model`` has no callable ``forward`` attribute.
    """
    model_forward = getattr(model, "forward", None)
    if not callable(model_forward):
        return False

    return supports_kw(model_forward, "intermediate_tensors")


@runtime_checkable
class HasInnerState(Protocol):
    """The interface required for all models that have inner state.

    Use :func:`has_inner_state` to check a model class or instance against
    this interface.
    """

    has_inner_state: ClassVar[Literal[True]] = True
    """
    A flag that indicates this model has inner state.
    Models that have inner state usually need access to the scheduler_config
    for max_num_seqs, etc. True for e.g. both Mamba and Jamba.
    """

    def __init__(self,
                 *,
                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
        ...


@runtime_checkable
class _HasInnerStateType(Protocol):
    has_inner_state: ClassVar[Literal[True]]

    def __init__(self,
                 *,
291
                 scheduler_config: Optional["SchedulerConfig"] = None) -> None:
292
293
294
295
        ...


# The Type[object] overload must come first: object is the broader type, so
# listing it first would make the class overload unreachable.
@overload
def has_inner_state(model: Type[object]) -> TypeIs[Type[HasInnerState]]:
    ...


@overload
def has_inner_state(model: object) -> TypeIs[HasInnerState]:
    ...


def has_inner_state(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[HasInnerState]], TypeIs[HasInnerState]]:
    """Return whether ``model`` (a model class or instance) implements
    :class:`HasInnerState`.

    A class is matched against :class:`_HasInnerStateType`; an instance is
    matched against :class:`HasInnerState`.
    """
    if isinstance(model, type):
        return isinstance(model, _HasInnerStateType)

    return isinstance(model, HasInnerState)


@runtime_checkable
class IsAttentionFree(Protocol):
    """Interface for models, like Mamba, that have no attention.

    Such models still carry state, but its size stays constant with respect
    to the number of tokens.
    """

    # Marks the model as having no attention; used for block manager and
    # attention backend selection. True for Mamba but not Jamba.
    is_attention_free: ClassVar[Literal[True]] = True

    def __init__(self) -> None:
        """Attention-free models take no interface-specific arguments."""


@runtime_checkable
class _IsAttentionFreeType(Protocol):
    is_attention_free: ClassVar[Literal[True]]

    def __init__(self) -> None:
        ...


# The Type[object] overload must come first: object is the broader type, so
# listing it first would make the class overload unreachable.
@overload
def is_attention_free(model: Type[object]) -> TypeIs[Type[IsAttentionFree]]:
    ...


@overload
def is_attention_free(model: object) -> TypeIs[IsAttentionFree]:
    ...


def is_attention_free(
    model: Union[Type[object], object]
) -> Union[TypeIs[Type[IsAttentionFree]], TypeIs[IsAttentionFree]]:
    """Return whether ``model`` (a model class or instance) implements
    :class:`IsAttentionFree`.

    A class is matched against :class:`_IsAttentionFreeType`; an instance is
    matched against :class:`IsAttentionFree`.
    """
    if isinstance(model, type):
        return isinstance(model, _IsAttentionFreeType)

    return isinstance(model, IsAttentionFree)


@runtime_checkable
class SupportsCrossEncoding(Protocol):
    """Interface implemented by models that support cross encoding.

    The protocol consists solely of the ``supports_cross_encoding`` flag.
    """

    # Marker flag; classes that inherit from this protocol get it for free.
    supports_cross_encoding: ClassVar[Literal[True]] = True


def _supports_cross_encoding(model: Union[Type[object], object]) -> bool:
    """Check only the :class:`SupportsCrossEncoding` flag of ``model``.

    The protocol has no methods, so the same ``isinstance`` check serves
    both model classes and model instances.
    """
    return isinstance(model, SupportsCrossEncoding)


# The overloads must be immediately followed by their implementation, so the
# private helper above is defined before them rather than in between.
@overload
def supports_cross_encoding(
        model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]:
    ...


@overload
def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]:
    ...


def supports_cross_encoding(
    model: Union[Type[object], object],
) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]:
    """Return whether ``model`` (a model class or instance) is an embedding
    model that also implements :class:`SupportsCrossEncoding`.

    Both :func:`is_embedding_model` and the cross-encoding flag check must
    pass.
    """
    return is_embedding_model(model) and _supports_cross_encoding(model)