"tests/plugins_tests/conftest.py" did not exist on "8c0d15d5c5658b74a70694124af2ac250fdc4e23"
interfaces_base.py 4.26 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from typing import (TYPE_CHECKING, Optional, Protocol, Union, overload,
4
                    runtime_checkable)
5
6
7
8
9
10
11
12
13

import torch
import torch.nn as nn
from typing_extensions import TypeIs, TypeVar

from vllm.logger import init_logger
from vllm.utils import supports_kw

if TYPE_CHECKING:
14
    from vllm.config import VllmConfig
15
16
17
18
19
20
21
22
    from vllm.model_executor.layers.pooler import PoolerOutput
    from vllm.model_executor.pooling_metadata import PoolingMetadata
    from vllm.model_executor.sampling_metadata import SamplingMetadata

# Module-level logger; used below to warn about `forward` signatures that
# lack vLLM-specific keywords.
logger = init_logger(__name__)

# The type of hidden states.
# Currently, T = torch.Tensor for all models except for Medusa,
# which has T = list[torch.Tensor].
24
25
26
27
28
29
30
31
32
# Hidden-state type for protocols where it appears as a method parameter
# (e.g. `compute_logits`, `pooler`); defaults to `torch.Tensor`.
T = TypeVar("T", default=torch.Tensor)
# Covariant counterpart used by `VllmModel`, where the hidden-state type
# only appears as the return type of `forward`.
T_co = TypeVar("T_co", default=torch.Tensor, covariant=True)

# NOTE: Unlike those in `interfaces.py`, we don't define `ClassVar` tags
# for the base interfaces to avoid breaking OOT registration for existing models
# that don't inherit from the base interface classes


@runtime_checkable
class VllmModel(Protocol[T_co]):
    """Structural interface that every model in vLLM must satisfy.

    ``T_co`` is the type of hidden states returned by ``forward``.
    """

    def __init__(self, vllm_config: "VllmConfig", prefix: str = "") -> None:
        """Build the model from ``vllm_config`` and an optional ``prefix``."""

    def forward(self, input_ids: torch.Tensor,
                positions: torch.Tensor) -> T_co:
        """Run the model on ``input_ids`` and ``positions``."""


51
def _check_vllm_model_init(model: Union[type[object], object]) -> bool:
    """Check the constructor of ``model`` for the ``vllm_config`` keyword.

    Args:
        model: A model class or a model instance.

    Returns:
        The result of ``supports_kw`` for ``model.__init__`` and
        ``"vllm_config"``.
    """
    return supports_kw(model.__init__, "vllm_config")
54
55


56
def _check_vllm_model_forward(model: Union[type[object], object]) -> bool:
    """Check the ``forward`` method of ``model`` for vLLM-specific keywords.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` if ``model`` has no callable ``forward`` attribute, or if
        ``supports_kw`` rejects ``input_ids`` or ``positions`` for it;
        ``True`` otherwise.
    """
    forward_fn = getattr(model, "forward", None)
    if not callable(forward_fn):
        return False

    absent = tuple(name for name in ("input_ids", "positions")
                   if not supports_kw(forward_fn, name))
    if not absent:
        return True

    # Warn only when `model` is a class derived from `nn.Module`; instances
    # and unrelated classes are rejected silently.
    if isinstance(model, type) and issubclass(model, nn.Module):
        logger.warning(
            "The model (%s) is missing "
            "vLLM-specific keywords from its `forward` method: %s",
            model,
            absent,
        )

    return False


@overload
def is_vllm_model(model: type[object]) -> TypeIs[type[VllmModel]]:
    ...


@overload
def is_vllm_model(model: object) -> TypeIs[VllmModel]:
    ...


def is_vllm_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModel]], TypeIs[VllmModel]]:
    """Tell whether ``model`` (a class or an instance) is a vLLM model.

    Both the ``__init__`` check and the ``forward`` check must pass; the
    ``forward`` check is skipped when the ``__init__`` check fails.
    """
    init_ok = _check_vllm_model_init(model)
    return init_ok and _check_vllm_model_forward(model)


@runtime_checkable
class VllmModelForTextGeneration(VllmModel[T], Protocol[T]):
    """Structural interface that every generative model in vLLM must satisfy.

    Extends ``VllmModel`` with a ``compute_logits`` method.
    """

    def compute_logits(self, hidden_states: T,
                       sampling_metadata: "SamplingMetadata") -> Optional[T]:
        """Return `None` if TP rank > 0."""


@overload
def is_text_generation_model(
        model: type[object]) -> TypeIs[type[VllmModelForTextGeneration]]:
    ...


@overload
def is_text_generation_model(
        model: object) -> TypeIs[VllmModelForTextGeneration]:
    ...


def is_text_generation_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForTextGeneration]],
           TypeIs[VllmModelForTextGeneration]]:
    """Tell whether ``model`` is a generative vLLM model.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` if ``is_vllm_model(model)`` fails; otherwise the result
        of the runtime ``VllmModelForTextGeneration`` protocol check.
    """
    if not is_vllm_model(model):
        return False

    # The runtime protocol check is the same expression for a class object
    # and for an instance, so no `isinstance(model, type)` branch is needed.
    return isinstance(model, VllmModelForTextGeneration)


@runtime_checkable
class VllmModelForPooling(VllmModel[T], Protocol[T]):
    """Structural interface that every pooling model in vLLM must satisfy.

    Extends ``VllmModel`` with a ``pooler`` method.
    """

    def pooler(self, hidden_states: T,
               pooling_metadata: "PoolingMetadata") -> "PoolerOutput":
        """Only called on TP rank 0."""


@overload
def is_pooling_model(model: type[object]) -> TypeIs[type[VllmModelForPooling]]:
    ...


@overload
def is_pooling_model(model: object) -> TypeIs[VllmModelForPooling]:
    ...


def is_pooling_model(
    model: Union[type[object], object],
) -> Union[TypeIs[type[VllmModelForPooling]], TypeIs[VllmModelForPooling]]:
    """Tell whether ``model`` is a pooling vLLM model.

    Args:
        model: A model class or a model instance.

    Returns:
        ``False`` if ``is_vllm_model(model)`` fails; otherwise the result
        of the runtime ``VllmModelForPooling`` protocol check.
    """
    if not is_vllm_model(model):
        return False

    # The runtime protocol check is the same expression for a class object
    # and for an instance, so no `isinstance(model, type)` branch is needed.
    return isinstance(model, VllmModelForPooling)