"vscode:/vscode.git/clone" did not exist on "3b17ea26e41b16a72935cdab7b7a771bfa1c25ef"
utils.py 16.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Utilities for selecting and loading models."""
import contextlib
5
6
7
import inspect
import warnings
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
from typing import Optional
10

zhuwenwen's avatar
zhuwenwen committed
11
import os
12
import torch
13
import transformers
14
from torch import nn
15
from transformers.dynamic_module_utils import get_class_from_dynamic_module
16

17
18
19
from vllm.attention import Attention
from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
                         set_current_vllm_config)
20
from vllm.logger import init_logger
21
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
22
from vllm.model_executor.layers.quantization.base_config import (
23
    QuantizationConfig, QuantizeMethodBase)
24
from vllm.model_executor.models import ModelRegistry
25
from vllm.model_executor.models.adapters import (as_embedding_model,
26
27
                                                 as_reward_model,
                                                 as_seq_cls_model)
28
import vllm.envs as envs
29
from vllm.model_executor.models.interfaces import SupportsQuant
30
31
from vllm.model_executor.models.registry import (_PREVIOUSLY_SUPPORTED_MODELS,
                                                 _TRANSFORMERS_MODELS)
32
from vllm.utils import is_pin_memory_available
33

34
35
# Module-level logger, created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)

36
37
38
39
40
41
42
43
44
45

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The dtype that was the default on entry is restored when the ``with``
    block exits, including when the block raises.

    Args:
        dtype: The dtype to install as torch's default inside the block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore on every exit path so that a failure inside the block
        # does not leak the temporary dtype into the rest of the process.
        torch.set_default_dtype(old_dtype)


46
47
48
49
50
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: Optional[type[nn.Module]] = None,
    model_config: Optional[ModelConfig] = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: The vLLM configuration the model is built from.
        prefix: Prefix forwarded to the model constructor (when accepted)
            and to `set_current_vllm_config`.
        model_class: Model class to instantiate. When ``None`` it is
            resolved with `get_model_architecture`.
        model_config: Model configuration to use. Defaults to
            ``vllm_config.model_config``.

    Returns:
        The instantiated model.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    # Names of the parameters accepted by the model's constructor.
    init_params = inspect.signature(model_class.__init__).parameters
    if "vllm_config" in init_params and "prefix" in init_params:
        # new-style model class
        with set_current_vllm_config(vllm_config,
                                     check_compile=True,
                                     prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class: pass only the
    # arguments its constructor declares. The getters are lazy so that a
    # config attribute is read only when the constructor asks for it.
    legacy_arguments = {
        "prefix": lambda: prefix,
        "config": lambda: model_config.hf_config,
        "cache_config": lambda: vllm_config.cache_config,
        "quant_config": lambda: vllm_config.quant_config,
        "lora_config": lambda: vllm_config.lora_config,
        "scheduler_config": lambda: vllm_config.scheduler_config,
    }
    kwargs = {
        name: getter()
        for name, getter in legacy_arguments.items() if name in init_params
    }
    with set_current_vllm_config(vllm_config,
                                 check_compile=True,
                                 prefix=prefix):
        return model_class(**kwargs)


def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                  target_device: torch.device) -> None:
    """Run the post-load weight processing hooks of ``model``'s submodules.

    Two passes are made over ``model.named_modules()``: the first handles
    cross-QKV layers and modules carrying a quantization method, the second
    handles `Attention` modules.

    Args:
        model: The model whose weights have been loaded.
        model_config: Supplies the ``dtype`` handed to attention modules.
        target_device: Device on which quantization methods process weights.
    """
    for _name, submodule in model.named_modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
        else:
            method = getattr(submodule, "quant_method", None)
            if not isinstance(method, QuantizeMethodBase):
                continue
            # Quant methods that post-process weights (repacking,
            # quantizing, ...) expect the parameters on the global target
            # device. With cpu offloading the context below moves the
            # parameters onto the device for processing and back afterwards.
            with device_loading_context(submodule, target_device):
                method.process_weights_after_loading(submodule)

    # Currently only used by MLA.
    # NOTE: This pass deliberately runs after the one above so the weights
    # can easily be decompressed for MLA.
    for _name, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if hasattr(submodule, "process_weights_after_loading"):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            submodule.process_weights_after_loading(model_config.dtype)


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily place ``module``'s CPU parameters on ``target_device``.

    On entry, every parameter of ``module`` that lives on the CPU is moved
    to ``target_device``. On exit (normal or exceptional), those same
    parameters are copied back into freshly allocated CPU tensors.
    Parameters that were not on the CPU on entry, and parameters created
    inside the block, are left untouched.

    Args:
        module: Module whose parameters are moved.
        target_device: Device to move to. When it is a CPU device nothing
            is moved at all.

    Yields:
        ``module`` itself.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Names of the parameters that were moved off the CPU by this context.
    moved_from_cpu: set[str] = set()

    try:
        # The move happens inside the `try` so that a failure partway
        # through still sends the already-moved parameters back to the CPU.
        for name, p in module.named_parameters():
            if p.device.type == "cpu":
                p.data = p.data.to(target_device)
                moved_from_cpu.add(name)
            # Parameters already on target device are not touched

        yield module

    finally:
        # Restore parameters to the CPU, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name not in moved_from_cpu:
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data
        # New parameters or parameters already on target device are untouched


174
175
def resolve_transformers_arch(model_config: ModelConfig,
                              architectures: list[str]):
    """Replace architectures with the Transformers backend class name.

    Every entry of ``architectures`` that is not already listed in
    `_TRANSFORMERS_MODELS` is checked for backend compatibility and then
    overwritten with ``model_config._get_transformers_backend_cls()``.

    Note that ``architectures`` is modified in place; the same list object
    is returned.

    Args:
        model_config: Model configuration (``model_impl``, ``hf_config``,
            ``model`` and ``revision`` are read).
        architectures: Architecture names to resolve.

    Returns:
        The (mutated) ``architectures`` list.

    Raises:
        ValueError: If ``model_impl`` is `ModelImpl.VLLM`, if no model
            module can be found for an architecture, or if the module is
            not backend compatible.
    """
    if model_config.model_impl == ModelImpl.VLLM:
        raise ValueError(
            "Attempting to resolve architecture from the Transformers library "
            "but the model implementation is set to vLLM. This should never "
            "happen.")

    for i, arch in enumerate(architectures):
        if arch in _TRANSFORMERS_MODELS:
            continue

        if model_config.model_impl == ModelImpl.AUTO:
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)

        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or {}
        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        # Sorting by key yields that order ("AutoConfig" < "AutoModel...").
        auto_modules = {
            name:
            get_class_from_dynamic_module(module,
                                          model_config.model,
                                          revision=model_config.revision)
            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
        }
        model_module = getattr(transformers, arch, None)
        if model_module is None:
            if "AutoModel" not in auto_map:
                raise ValueError(
                    f"Cannot find model module. '{arch}' is not a registered "
                    "model in the Transformers library (only relevant if the "
                    "model is meant to be in Transformers) and 'AutoModel' is "
                    "not present in the model config's 'auto_map' (relevant "
                    "if the model is custom).")
            model_module = auto_modules["AutoModel"]

        if not model_module.is_backend_compatible():
            raise ValueError(
                f"The Transformers implementation of '{arch}' is not "
                "compatible with vLLM.")

        architectures[i] = model_config._get_transformers_backend_cls()
    return architectures


229
def _configure_nn_env(architectures: list[str], visions: object) -> None:
    """Set the NN / padding environment variables for ``architectures``.

    Writes ``LLAMA_NN``, ``LM_NN``, ``GEMM_PAD``, ``FA_PAD``, ``AWQ_PAD``
    and ``AWQ_MOE_SZ`` into ``os.environ``.

    NOTE(review): the consumers of these variables are not visible in this
    file; this helper only preserves the existing selection rules.
    """
    # TODO: 'Qwen2_5_VLForConditionalGeneration',
    support_nn_architectures = [
        'LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM',
        'Qwen2VLForConditionalGeneration', 'Qwen2MoeForCausalLM',
        'Qwen3ForCausalLM', 'Qwen3MoeForCausalLM', 'ChatGLMModel',
        'Glm4ForCausalLM', 'ChatGLMForConditionalGeneration',
        'BaichuanForCausalLM', 'BloomForCausalLM', 'TeleChat2ForCausalLM',
        'MixtralForCausalLM', 'FalconForCausalLM', 'MedusaModel',
        'MLPSpeculatorPreTrainedModel', 'DeepseekV2ForCausalLM',
        'DeepseekV3ForCausalLM', 'DeepSeekMTPModel'
    ]
    if not any(arch in architectures for arch in support_nn_architectures):
        # No listed architecture is present: switch every toggle off.
        os.environ['LLAMA_NN'] = '0'
        os.environ['LM_NN'] = '0'
        os.environ['GEMM_PAD'] = '0'
        os.environ['FA_PAD'] = '0'
        os.environ['AWQ_PAD'] = '0'
        return

    if not envs.VLLM_USE_NN:
        # An explicit LLAMA_NN=0 from the user is respected as is.
        if os.getenv('LLAMA_NN') != '0':
            if (architectures == ['QWenLMHeadModel']
                    or architectures == ['ChatGLMModel']) and visions != []:
                os.environ['LLAMA_NN'] = '0'
            else:
                os.environ['LLAMA_NN'] = '1'
        if (architectures == ['BloomForCausalLM']
                or architectures == ['FalconForCausalLM']
            ) or os.getenv('LM_NN') == '0':
            os.environ['LM_NN'] = '0'
        else:
            os.environ['LM_NN'] = '1'
        # Padding stays enabled only when the user explicitly set it to '1'.
        if os.getenv('GEMM_PAD') != '1':
            os.environ['GEMM_PAD'] = '0'
        if os.getenv('FA_PAD') != '1':
            os.environ['FA_PAD'] = '0'

    # AWQ-related configuration.
    try:
        if os.getenv('AWQ_MOE_SZ') is None:
            os.environ['AWQ_MOE_SZ'] = '1'
        if os.getenv('AWQ_PAD') is None and (
                torch.cuda.get_device_properties(
                    torch.cuda.current_device()).multi_processor_count
                == 120):
            os.environ['AWQ_PAD'] = '1'
    except Exception:
        # Best effort: the device query above can fail (e.g. when no CUDA
        # device can be queried). Fall back to AWQ_PAD=1 unless it is
        # explicitly '0', in which case it is left at '0'.
        if os.getenv('AWQ_PAD') != '0':
            os.environ['AWQ_PAD'] = '1'


def get_model_architecture(
        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    Steps, in order:

    1. Read ``architectures`` (and ``visual`` / ``vision_config``) from
       ``model_config.hf_config`` and set the related environment variables.
    2. If no architecture has a native vLLM implementation, try to map a
       ``*ForSequenceClassification`` architecture to its ``*ForCausalLM``
       counterpart.
    3. Reject architectures listed in `_PREVIOUSLY_SUPPORTED_MODELS`.
    4. Resolve through the Transformers backend when requested or needed,
       or switch quantized Mixtral to ``QuantMixtralForCausalLM``.
    5. Resolve the class through `ModelRegistry` and wrap it with the
       adapter matching ``model_config.task``.

    Args:
        model_config: The model configuration to resolve.

    Returns:
        A ``(model_cls, arch)`` tuple.

    Raises:
        ValueError: If an architecture was supported only by older vLLM
            versions (further errors may come from the resolvers called).
        AssertionError: If a ``*ForSequenceClassification`` architecture is
            converted while ``model_config.task`` is not ``"classify"``.
    """
    architectures = getattr(model_config.hf_config, "architectures", [])
    visions = (getattr(model_config.hf_config, "visual", [])
               or getattr(model_config.hf_config, "vision_config", []))
    _configure_nn_env(architectures, visions)

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8",
        "compressed-tensors",
        "gptq_marlin",
        "awq_marlin",
        "quark",
        "bitsandbytes",
    ]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    # An architecture counts as natively supported when vLLM lists it and
    # it is not one of the Transformers backend classes.
    vllm_not_supported = not any(
        arch in vllm_supported_archs and arch not in _TRANSFORMERS_MODELS
        for arch in architectures)

    if vllm_not_supported:
        # try automatic conversion in adapters.py
        for arch in architectures:
            if not arch.endswith("ForSequenceClassification"):
                continue

            assert model_config.task == "classify"
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            if causal_lm_arch not in vllm_supported_archs:
                continue

            architectures = [causal_lm_arch]
            vllm_not_supported = False
            break

    # Look up the architecture that actually matched: indexing with
    # `architectures[0]` raised KeyError when the match was a later entry.
    removed_arch = next(
        (arch
         for arch in architectures if arch in _PREVIOUSLY_SUPPORTED_MODELS),
        None)
    if removed_arch is not None:
        previous_version = _PREVIOUSLY_SUPPORTED_MODELS[removed_arch]
        raise ValueError(
            f"Model architecture {removed_arch} was supported"
            f" in vLLM until version {previous_version}, and is "
            "not supported anymore. Please use an older version"
            " of vLLM if you want to use this model architecture.")

    if (model_config.model_impl == ModelImpl.TRANSFORMERS
            or (model_config.model_impl == ModelImpl.AUTO
                and vllm_not_supported)):
        architectures = resolve_transformers_arch(model_config, architectures)
        logger.debug_once("Resolve transformers arch %s", str(architectures))
    elif (model_config.quantization is not None
          and model_config.quantization not in mixtral_supported
          and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        logger.debug_once("Automatic conversion using `as_embedding_model`.")
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
        model_cls = as_seq_cls_model(model_cls)
    elif model_config.task == "reward":
        logger.debug_once("Automatic conversion using `as_reward_model`.")
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
333
334


335
336
337
338
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the model class resolved by `get_model_architecture`."""
    model_cls, _arch = get_model_architecture(model_config)
    return model_cls


339
340
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name from `get_model_architecture`."""
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
341
342
343
344
345
346
347
348
349


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Packed parameter name -> names of the sub-parameters it is made of.
    packed_mapping: dict[str, list[str]]
    # Sub-parameter name -> (packed parameter name, index within the pack).
    # Filled in by `__post_init__` from `packed_mapping`.
    inverse_packed_mapping: dict[str, tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        """Populate `inverse_packed_mapping` from `packed_mapping`."""
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self,
                        module_name: str) -> Optional[tuple[str, list[str]]]:
        """Find the packed entry whose name is a suffix of ``module_name``.

        Returns:
            ``(packed_name, sub_params)`` for the first entry of
            `packed_mapping` whose key ``module_name`` ends with, or
            ``None`` when there is no such entry.
        """
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
371
372
373


def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: type[nn.Module]):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed

    Args:
        quant_config: Quantization config to update in place.
        model_class: Model class whose ``hf_to_vllm_mapper`` and
            ``packed_modules_mapping`` class attributes (if any) are used.
            Classes deriving from `SupportsQuant` are skipped.
    """
    if not issubclass(model_class, SupportsQuant):
        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)

        # pass mappings by reference to quant_config
        if hf_to_vllm_mapper is not None:
            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
        if packed_mapping is not None:
            quant_config.packed_modules_mapping = packed_mapping