utils.py 14.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
"""Utilities for selecting and loading models."""
import contextlib
5
6
7
import inspect
import warnings
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
from typing import Optional
10
11

import torch
12
import transformers
13
from torch import nn
14
from transformers.dynamic_module_utils import get_class_from_dynamic_module
15

16
17
18
from vllm.attention import Attention
from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
                         set_current_vllm_config)
19
from vllm.logger import init_logger
20
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
21
from vllm.model_executor.layers.quantization.base_config import (
22
    QuantizationConfig, QuantizeMethodBase)
23
from vllm.model_executor.models import ModelRegistry
24
from vllm.model_executor.models.adapters import (as_embedding_model,
25
26
                                                 as_reward_model,
                                                 as_seq_cls_model)
27
from vllm.model_executor.models.interfaces import SupportsQuant
28
from vllm.utils import is_pin_memory_available
29

30
31
logger = init_logger(__name__)

32
33
34
35
36
37
38
39
40
41

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The dtype that was the default on entry is put back when the ``with``
    block exits, whether it exits normally or by raising.

    Args:
        dtype: The dtype to install as torch's default for the duration of
            the ``with`` block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore in ``finally``: otherwise an exception raised inside the
        # ``with`` body would leave the process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)


42
43
44
45
46
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: Optional[type[nn.Module]] = None,
    model_config: Optional[ModelConfig] = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: The overall configuration. Its ``quant_config`` (when
            set) is wired to the model class before instantiation, and the
            whole object is installed as the current config while the model
            is being constructed.
        prefix: Passed through to the model class constructor.
        model_class: Class to instantiate. When ``None`` it is resolved with
            ``get_model_architecture(model_config)``.
        model_config: Model configuration to use. When ``None`` it defaults
            to ``vllm_config.model_config``.

    Returns:
        The instantiated model.

    A model class whose ``__init__`` accepts both ``vllm_config`` and
    ``prefix`` is treated as "new-style" and is called with exactly those two
    keyword arguments. Any other class is treated as "old-style": a
    ``DeprecationWarning`` is issued and the constructor arguments are
    guessed from the parameter names it declares.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    # Inspect the constructor to decide which calling convention to use.
    signatures = inspect.signature(model_class.__init__)
    all_params = [param.name for param in signatures.parameters.values()]
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config,
                                     check_compile=True,
                                     prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class: only pass the
    # arguments that the constructor actually declares.
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config,
                                 check_compile=True,
                                 prefix=prefix):
        return model_class(**kwargs)


def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                  target_device: torch.device) -> None:
    """Run the post-load weight processing hooks of ``model``'s submodules.

    Two passes are made over ``model.named_modules()``:

    1. ``QKVCrossParallelLinear`` modules have their own hook called
       directly; every other module whose ``quant_method`` attribute is a
       ``QuantizeMethodBase`` has that method's hook called inside
       ``device_loading_context(module, target_device)``.
    2. ``Attention`` modules that define ``process_weights_after_loading``
       are called with ``model_config.dtype``.
    """
    for _name, submodule in model.named_modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
        else:
            method = getattr(submodule, "quant_method", None)
            if not isinstance(method, QuantizeMethodBase):
                continue
            # Quant methods that post-process weights (repacking,
            # quantizing, etc.) expect the parameters to be on the global
            # target device. With cpu offloading in use, the context below
            # moves the parameters onto the device for processing and back
            # off afterwards.
            with device_loading_context(submodule, target_device):
                method.process_weights_after_loading(submodule)

    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for _name, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)


@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move CPU-resident parameters of ``module`` to
    ``target_device``.

    On entry, every parameter of ``module`` that currently lives on the CPU
    has its ``.data`` moved to ``target_device``. On exit, those same
    parameters (matched by name) are copied back into freshly allocated CPU
    tensors, pinned when ``is_pin_memory_available()`` says so. Parameters
    that were not on the CPU at entry, and parameters that first appear
    inside the ``with`` body, are left where they are.

    When ``target_device`` is itself a CPU device nothing is moved.

    Args:
        module: Module whose parameters are relocated.
        target_device: Device the parameters should be on inside the block.

    Yields:
        ``module`` itself.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Names of the parameters that were moved off the CPU. Only CPU devices
    # are ever recorded, so the restore step only has to handle CPU.
    moved_from_cpu: set[str] = set()

    try:
        # The moves are inside the ``try`` so that a failure part-way through
        # still sends the already-moved parameters back to the CPU.
        for name, p in module.named_parameters():
            if p.device.type == "cpu":
                p.data = p.data.to(target_device)
                moved_from_cpu.add(name)
            # Parameters already on target device are not touched

        yield module

    finally:
        # Restore parameters to the CPU, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name not in moved_from_cpu:
                # New parameters or parameters that were never on the CPU
                # are untouched
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            cpu_data = torch.empty_strided(
                size=p.data.size(),
                stride=p.data.stride(),
                dtype=p.data.dtype,
                layout=p.data.layout,
                device="cpu",
                pin_memory=pin_memory,
            )
            cpu_data.copy_(p.data)
            p.data = cpu_data


170
171
def resolve_transformers_arch(model_config: ModelConfig,
                              architectures: list[str]):
    """Rewrite architecture names to the Transformers fallback where allowed.

    For every entry other than ``"TransformersForCausalLM"`` the matching
    model class is looked up, first as an attribute of the ``transformers``
    package and otherwise as the ``"AutoModel"`` entry of the config's
    ``auto_map``. Depending on ``model_config.model_impl`` the entry is then
    replaced with ``"TransformersForCausalLM"``:

    * ``ModelImpl.TRANSFORMERS``: replaced, or ``ValueError`` if the class
      reports ``is_backend_compatible()`` as false.
    * ``ModelImpl.AUTO``: same, and a warning about the fallback is logged.
    * any other value: the entry is left unchanged.

    Args:
        model_config: Supplies ``hf_config`` (for ``auto_map``), ``model``,
            ``revision`` and ``model_impl``.
        architectures: Architecture names to resolve.

    Returns:
        The same ``architectures`` list object that was passed in.

    Raises:
        ValueError: If no model class can be found for an architecture, or
            the class found is not backend-compatible.

    NOTE: ``architectures`` is modified in place, so the caller's list
    changes as well.
    """
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue
        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()
        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        auto_modules = {
            name:
            get_class_from_dynamic_module(module,
                                          model_config.model,
                                          revision=model_config.revision)
            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
        }
        model_module = getattr(transformers, arch, None)
        if model_module is None:
            if "AutoModel" not in auto_map:
                raise ValueError(
                    f"Cannot find model module. '{arch}' is not a registered "
                    "model in the Transformers library (only relevant if the "
                    "model is meant to be in Transformers) and 'AutoModel' is "
                    "not present in the model config's 'auto_map' (relevant "
                    "if the model is custom).")
            model_module = auto_modules["AutoModel"]
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersForCausalLM"
        if model_config.model_impl == ModelImpl.AUTO:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM. Try setting "
                    "VLLM_USE_V1=0.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersForCausalLM"
    return architectures


224
def get_model_architecture(
        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    Steps, in order:

    1. Read ``architectures`` from ``model_config.hf_config`` (empty list if
       the attribute is missing).
    2. If none of them is in ``ModelRegistry.get_supported_archs()``, try to
       map a ``*ForSequenceClassification`` name to the matching
       ``*ForCausalLM`` name when the latter is supported.
    3. Use ``resolve_transformers_arch`` when ``model_impl`` is
       ``TRANSFORMERS``, or when it is not ``VLLM`` and no architecture is
       supported. Otherwise, swap in ``"QuantMixtralForCausalLM"`` for a
       quantized Mixtral whose quantization method is not in
       ``mixtral_supported``.
    4. Resolve the class through ``ModelRegistry.resolve_model_cls`` and wrap
       it with the adapter for the ``embed``, ``classify`` or ``reward``
       task.

    Args:
        model_config: Supplies ``hf_config``, ``model_impl``, ``quantization``
            and ``task``.

    Returns:
        ``(model_cls, arch)``: the (possibly adapter-wrapped) model class and
        the architecture name returned by the registry.
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8",
        "compressed-tensors",
        "gptq_marlin",
        "awq_marlin",
        "quark",
        "bitsandbytes",
    ]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    vllm_not_supported = not any(arch in vllm_supported_archs
                                 for arch in architectures)

    if vllm_not_supported:
        # try automatic conversion in adapters.py
        for arch in architectures:
            if not arch.endswith("ForSequenceClassification"):
                continue

            assert model_config.task == "classify"
            causal_lm_arch = arch.replace("ForSequenceClassification",
                                          "ForCausalLM")
            causal_lm_arch_vllm_supported = (causal_lm_arch
                                             in vllm_supported_archs)
            if not causal_lm_arch_vllm_supported:
                continue

            # The first convertible architecture wins and replaces the list.
            architectures = [causal_lm_arch]
            vllm_not_supported = False
            break

    # ``and`` binds tighter than ``or``; the parentheses only make the
    # existing grouping explicit.
    if (model_config.model_impl == ModelImpl.TRANSFORMERS
            or (model_config.model_impl != ModelImpl.VLLM
                and vllm_not_supported)):
        architectures = resolve_transformers_arch(model_config, architectures)
        logger.debug_once("Resolve transformers arch %s", str(architectures))
    elif (model_config.quantization is not None
          and model_config.quantization not in mixtral_supported
          and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        logger.debug_once("Automatic conversion using `as_embedding_model`.")
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        logger.debug_once("Automatic conversion using `as_seq_cls_model`.")
        model_cls = as_seq_cls_model(model_cls)
    elif model_config.task == "reward":
        logger.debug_once("Automatic conversion using `as_reward_model`.")
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
282
283


284
285
286
287
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the model class resolved by ``get_model_architecture``."""
    model_cls, _ = get_model_architecture(model_config)
    return model_cls


288
289
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name from ``get_model_architecture``."""
    _, arch_name = get_model_architecture(model_config)
    return arch_name
290
291
292
293
294
295
296
297
298


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Packed parameter name -> names of the parameters packed into it.
    packed_mapping: dict[str, list[str]]
    # Constituent parameter name -> (packed parameter name, position in the
    # packed list). Filled in by ``__post_init__``.
    inverse_packed_mapping: dict[str, tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        """Build ``inverse_packed_mapping`` from ``packed_mapping``."""
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self,
                        module_name: str) -> Optional[tuple[str, list[str]]]:
        """Find the packed entry whose name is a suffix of ``module_name``.

        Returns:
            ``(packed_name, sub_params)`` for the first key of
            ``packed_mapping`` (in dict order) that ``module_name`` ends
            with, or ``None`` when no key matches.
        """
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
320
321
322


def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: type[nn.Module]):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed

    Args:
        quant_config: Quantization config to update in place.
        model_class: Model class whose ``hf_to_vllm_mapper`` and
            ``packed_modules_mapping`` class attributes (when present) are
            handed to ``quant_config``. Nothing is done for classes that
            subclass ``SupportsQuant``.
    """
    if not issubclass(model_class, SupportsQuant):
        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)

        # pass mappings by reference to quant_config
        if hf_to_vllm_mapper is not None:
            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
        if packed_mapping is not None:
            quant_config.packed_modules_mapping = packed_mapping