# NOTE: source-viewer notice captured during extraction (not part of the code):
# "vllm/compilation/passes/pass_manager.py" did not exist on "fc0f41d10aca510658a4d86c8bff2e6781d5d669"
# File: registry.py
import importlib
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

from .interfaces import supports_multimodal, supports_pp

logger = init_logger(__name__)

# Text-generation architectures.
#
# Architecture name -> (module, class). The module name is relative to
# ``vllm.model_executor.models`` (ModelRegistry._get_module_cls_name adds that
# prefix) and the class is looked up as an attribute of that module.
# Several architecture names deliberately share one implementation
# (e.g. the Aquila, InternLM and Mistral entries all resolve to the Llama
# class).
_GENERATION_MODELS = {
    "AquilaModel": ("llama", "LlamaForCausalLM"),
    "AquilaForCausalLM": ("llama", "LlamaForCausalLM"),  # AquilaChat2
    "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
    "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
    "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
    "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
    "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
    "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
    "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
    "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
    "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
    "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
    "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
    "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
    "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
    "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
    "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
    "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
    "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
    "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
    "GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
    "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
    "InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
    "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
    "JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
    "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
    # For decapoda-research/llama-*
    "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
    "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
    "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
    # transformers's mpt class has lower case
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
    "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
    "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
    "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
    "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
    "OrionForCausalLM": ("orion", "OrionForCausalLM"),
    "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
    # NOTE(review): this architecture is also listed in _MULTIMODAL_MODELS
    # with the same (module, class) value, so the merged _MODELS table is
    # unaffected by the duplication.
    "Qwen2VLForConditionalGeneration":
    ("qwen2_vl", "Qwen2VLForConditionalGeneration"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
    "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
    "SolarForCausalLM": ("solar", "SolarForCausalLM"),
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    # NOTE: The below models are for speculative decoding only
    "MedusaModel": ("medusa", "Medusa"),
    "EAGLEModel": ("eagle", "EAGLE"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

# Embedding / reward-model architectures: architecture name -> (module, class).
# Membership in this table is exactly what ModelRegistry.is_embedding_model
# tests, so adding an entry here also marks the architecture as an embedding
# model.
_EMBEDDING_MODELS = {
    "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
    "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
}

# Multimodal architectures: architecture name -> (module, class).
# This table only groups the entries; ModelRegistry.is_multimodal_model does
# not consult it and instead runs `supports_multimodal` against the resolved
# model class.
_MULTIMODAL_MODELS = {
    "Blip2ForConditionalGeneration":
    ("blip2", "Blip2ForConditionalGeneration"),
    "ChameleonForConditionalGeneration":
    ("chameleon", "ChameleonForConditionalGeneration"),
    "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
    "InternVLChatModel": ("internvl", "InternVLChatModel"),
    "LlavaForConditionalGeneration": ("llava",
                                      "LlavaForConditionalGeneration"),
    "LlavaNextForConditionalGeneration": ("llava_next",
                                          "LlavaNextForConditionalGeneration"),
    "LlavaNextVideoForConditionalGeneration":
    ("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
    "LlavaOnevisionForConditionalGeneration":
    ("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
    "MiniCPMV": ("minicpmv", "MiniCPMV"),
    "PaliGemmaForConditionalGeneration": ("paligemma",
                                          "PaliGemmaForConditionalGeneration"),
    "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
    "PixtralForConditionalGeneration": ("pixtral",
                                        "PixtralForConditionalGeneration"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    # Also present in _GENERATION_MODELS with the same value.
    "Qwen2VLForConditionalGeneration": ("qwen2_vl",
                                        "Qwen2VLForConditionalGeneration"),
    "UltravoxModel": ("ultravox", "UltravoxModel"),
    "MllamaForConditionalGeneration": ("mllama",
                                       "MllamaForConditionalGeneration"),
}
# Conditional-generation architectures: architecture name -> (module, class).
# Both BART architecture names resolve to the same implementation class.
_CONDITIONAL_GENERATION_MODELS = {
    "BartModel": ("bart", "BartForConditionalGeneration"),
    "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

# Every built-in architecture -> (module, class), merged from the per-category
# tables. Tables are folded in order, so a key that appears in more than one
# table keeps its first position and takes the value of the last table that
# defines it.
_MODELS = {
    arch: target
    for table in (
        _GENERATION_MODELS,
        _EMBEDDING_MODELS,
        _MULTIMODAL_MODELS,
        _CONDITIONAL_GENERATION_MODELS,
    )
    for arch, target in table.items()
}

128
# Architecture -> type or (module, class).
129
130
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}
131
_OOT_MODELS_LAZY: Dict[str, Tuple[str, str]] = {}
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
                    "Triton flash attention. For half-precision SWA support, "
                    "please use CK flash attention by setting "
                    "`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
    "Qwen2ForCausalLM":
    _ROCM_SWA_REASON,
    "MistralForCausalLM":
    _ROCM_SWA_REASON,
    "MixtralForCausalLM":
    _ROCM_SWA_REASON,
    "PaliGemmaForConditionalGeneration":
    ("ROCm flash attention does not yet "
     "fully support 32-bit precision on PaliGemma"),
    "Phi3VForCausalLM":
    ("ROCm Triton flash attention may run into compilation errors due to "
     "excessive use of shared memory. If this happens, disable Triton FA "
     "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


class ModelRegistry:
    """
    Resolve model architecture names to model classes.

    Built-in architectures are looked up in :code:`_MODELS`. Models added
    through :meth:`register_model` are stored in :code:`_OOT_MODELS` (class
    objects) or :code:`_OOT_MODELS_LAZY` (:code:`(module, class)` names) and
    take precedence over built-in entries of the same name.

    All methods are static; the registry state lives in the module-level
    tables.
    """

    @staticmethod
    def _normalize_archs(
            architectures: Union[str, List[str], None]) -> List[str]:
        """
        Return :code:`architectures` as a list of architecture names.

        A single string is wrapped in a list. An empty (or missing) value
        logs a warning and yields an empty list so that callers can iterate
        over the result unconditionally.
        """
        if isinstance(architectures, str):
            return [architectures]
        if not architectures:
            logger.warning("No model architectures are specified")
            return []
        return architectures

    @staticmethod
    def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
        """
        Return the :code:`(module, class)` names used to import a model.

        Lazily registered out-of-tree models are checked first so that they
        override a built-in architecture of the same name, which is what
        :meth:`register_model` announces in its warning.

        :raises KeyError: If the architecture is neither lazily registered
            nor built in.
        """
        if model_arch in _OOT_MODELS_LAZY:
            return _OOT_MODELS_LAZY[model_arch]

        if model_arch in _MODELS:
            module_relname, cls_name = _MODELS[model_arch]
            return f"vllm.model_executor.models.{module_relname}", cls_name

        raise KeyError(model_arch)

    @staticmethod
    @lru_cache(maxsize=128)
    def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
        """
        Import the module of :code:`model_arch` and return its model class.

        Returns :code:`None` if the architecture is unknown or if the module
        does not define the expected class. The result is cached per
        architecture; :meth:`register_model` clears the cache.
        """
        try:
            mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
        except KeyError:
            return None

        module = importlib.import_module(mod_name)
        return getattr(module, cls_name, None)

    @staticmethod
    def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
        """
        Return the model class if it is available without importing anything.

        Only models registered with a class object qualify. For every other
        architecture this returns :code:`None`, after checking ROCm support
        when running on ROCm.

        :raises ValueError: If running on ROCm and the architecture is listed
            as unsupported there.
        """
        if model_arch in _OOT_MODELS:
            return _OOT_MODELS[model_arch]

        if is_hip():
            if model_arch in _ROCM_UNSUPPORTED_MODELS:
                raise ValueError(
                    f"Model architecture {model_arch} is not supported by "
                    "ROCm for now.")
            if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    "Model architecture %s is partially supported by ROCm: %s",
                    model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])

        return None

    @staticmethod
    def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
        """
        Return the model class for :code:`model_arch`, or :code:`None`.

        Already-available classes are preferred; otherwise the model module
        is imported.
        """
        model = ModelRegistry._try_get_model_stateless(model_arch)
        if model is not None:
            return model

        return ModelRegistry._try_get_model_stateful(model_arch)

    @staticmethod
    def resolve_model_cls(
        architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
        """
        Return the first loadable model class among :code:`architectures`.

        :returns: A tuple of the model class and the architecture name that
            resolved to it.
        :raises ValueError: If none of the architectures can be resolved.
        """
        architectures = ModelRegistry._normalize_archs(architectures)

        for arch in architectures:
            model_cls = ModelRegistry._try_load_model_cls(arch)
            if model_cls is not None:
                return (model_cls, arch)

        raise ValueError(
            f"Model architectures {architectures} are not supported for now. "
            f"Supported architectures: {ModelRegistry.get_supported_archs()}")

    @staticmethod
    def get_supported_archs() -> List[str]:
        """
        Return the names of all resolvable architectures, without duplicates.

        Built-in architectures come first, followed by out-of-tree models
        registered by class and then those registered lazily.
        """
        return list(
            dict.fromkeys((*_MODELS, *_OOT_MODELS, *_OOT_MODELS_LAZY)))

    @staticmethod
    def register_model(model_arch: str, model_cls: Union[Type[nn.Module],
                                                         str]):
        """
        Register an external model to be used in vLLM.

        :code:`model_cls` can be either:

        - A :class:`torch.nn.Module` class directly referencing the model.
        - A string in the format :code:`<module>:<class>` which can be used to
          lazily import the model. This is useful to avoid initializing CUDA
          when importing the model and thus the related error
          :code:`RuntimeError: Cannot re-initialize CUDA in forked subprocess`.

        Registering an architecture that already exists (built in or
        previously registered) logs a warning and replaces it.

        :raises ValueError: If :code:`model_cls` is a string that is not of
            the form :code:`<module>:<class>`.
        """
        lazy_target: Optional[Tuple[str, str]] = None
        if isinstance(model_cls, str):
            split_str = model_cls.split(":")
            # Both parts must be present; ":Cls" or "mod:" cannot be imported.
            if len(split_str) != 2 or not all(split_str):
                msg = "Expected a string in the format `<module>:<class>`"
                raise ValueError(msg)
            lazy_target = (split_str[0], split_str[1])

        if (model_arch in _MODELS or model_arch in _OOT_MODELS
                or model_arch in _OOT_MODELS_LAZY):
            logger.warning(
                "Model architecture %s is already registered, and will be "
                "overwritten by the new model class %s.", model_arch,
                model_cls)

        # Keep at most one out-of-tree entry per architecture so that an
        # earlier registration of the other kind cannot shadow this one.
        if lazy_target is not None:
            _OOT_MODELS.pop(model_arch, None)
            _OOT_MODELS_LAZY[model_arch] = lazy_target
        else:
            _OOT_MODELS_LAZY.pop(model_arch, None)
            _OOT_MODELS[model_arch] = model_cls

        # Cached lookups may still refer to the previous registration.
        ModelRegistry._try_get_model_stateful.cache_clear()
        ModelRegistry._check_stateless.cache_clear()

    @staticmethod
    @lru_cache(maxsize=128)
    def _check_stateless(
        func: Callable[[Type[nn.Module]], bool],
        model_arch: str,
        *,
        default: Optional[bool] = None,
    ) -> bool:
        """
        Run a boolean function against a model and return the result.

        If the model is not found, returns the provided default value.

        If the model is not already imported, the function is run inside a
        subprocess to avoid initializing CUDA for the main program.

        :raises KeyError: If the model is not found and no default is given.
        :raises ValueError: If a module, class or function name contains
            characters that are unsafe to interpolate into the subprocess
            command.
        :raises RuntimeError: If the subprocess fails for any reason other
            than :code:`func` returning a falsy value.
        """
        model = ModelRegistry._try_get_model_stateless(model_arch)
        if model is not None:
            return func(model)

        try:
            mod_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
        except KeyError:
            if default is not None:
                return default

            raise

        # The names below are spliced into Python source that is executed in
        # a subprocess, so only plain dotted identifiers are allowed.
        valid_name_characters = frozenset(string.ascii_letters +
                                          string.digits + "._")
        name_checks = (
            (mod_name, f"Unsafe module name detected for {model_arch}"),
            (cls_name, f"Unsafe class name detected for {model_arch}"),
            (func.__module__, f"Unsafe module name detected for {func}"),
            (func.__name__, f"Unsafe class name detected for {func}"),
        )
        for name, msg in name_checks:
            if any(ch not in valid_name_characters for ch in name):
                raise ValueError(msg)

        # A unique marker distinguishes "func returned False" from any other
        # failure of the subprocess.
        # NOTE: the protocol relies on `assert`, which is stripped when the
        # child interpreter runs with optimizations enabled
        # (e.g. PYTHONOPTIMIZE is set in the environment).
        err_id = uuid.uuid4()

        stmts = ";".join([
            f"from {mod_name} import {cls_name}",
            f"from {func.__module__} import {func.__name__}",
            f"assert {func.__name__}({cls_name}), '{err_id}'",
        ])

        result = subprocess.run([sys.executable, "-c", stmts],
                                capture_output=True)

        if result.returncode == 0:
            return True

        err_lines = [
            line.decode(errors="replace")
            for line in result.stderr.splitlines()
        ]
        # Anything other than our own assertion is a real failure. This
        # includes a non-zero exit without any output (e.g. the child was
        # killed), which must not be mistaken for a negative result.
        if not err_lines or err_lines[-1] != f"AssertionError: {err_id}":
            err_str = "\n".join(err_lines)
            raise RuntimeError(
                "An unexpected error occurred while importing the model in "
                f"another process. Error log:\n{err_str}")

        return False

    @staticmethod
    def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
        """
        Return whether any of the architectures is a known embedding model.
        """
        architectures = ModelRegistry._normalize_archs(architectures)

        return any(arch in _EMBEDDING_MODELS for arch in architectures)

    @staticmethod
    def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
        """
        Return whether any of the architectures passes
        :code:`supports_multimodal`. Unknown architectures count as
        :code:`False`.
        """
        architectures = ModelRegistry._normalize_archs(architectures)

        is_mm = partial(ModelRegistry._check_stateless,
                        supports_multimodal,
                        default=False)

        return any(is_mm(arch) for arch in architectures)

    @staticmethod
    def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
        """
        Return whether any of the architectures passes :code:`supports_pp`.
        Unknown architectures count as :code:`False`.
        """
        architectures = ModelRegistry._normalize_archs(architectures)

        is_pp = partial(ModelRegistry._check_stateless,
                        supports_pp,
                        default=False)

        return any(is_pp(arch) for arch in architectures)