utils.py 6.97 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Utilities for selecting and loading models."""
import contextlib
4
from dataclasses import dataclass, field
5
from typing import Dict, List, Optional, Tuple, Type
6
7

import torch
8
import transformers
9
from torch import nn
10
from transformers.dynamic_module_utils import get_class_from_dynamic_module
11

12
13
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
14
15
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
16
from vllm.model_executor.models import ModelRegistry
17
18
19
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
20

21
22
# Module-level logger used for the warnings emitted in this file.
logger = init_logger(__name__)

23
24
25
26
27
28
29
30
31
32

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The default that was active on entry is restored when the ``with`` block
    exits, whether it finishes normally or raises.

    Args:
        dtype: The dtype to install as torch's default inside the block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore on every exit path so an exception inside the block does
        # not leak the overridden default into the rest of the process.
        torch.set_default_dtype(old_dtype)


33
34
def is_transformers_impl_compatible(
        arch: str,
        module: Optional["transformers.PreTrainedModel"] = None) -> bool:
    """Report whether a Transformers model class is backend compatible.

    Args:
        arch: Architecture name. It is looked up as an attribute of the
            ``transformers`` package when ``module`` is not supplied.
        module: Optional model class to check instead of the attribute
            found on ``transformers``.

    Returns:
        ``False`` when neither ``module`` nor ``transformers.<arch>`` yields
        a class; otherwise the result of that class's
        ``is_backend_compatible()``.
    """
    # An explicitly supplied class takes precedence over the name lookup.
    mod = module or getattr(transformers, arch, None)
    if mod is None:
        return False
    return mod.is_backend_compatible()
40
41


42
43
def resolve_transformers_arch(model_config: ModelConfig,
                              architectures: list[str]):
    """Replace architecture names with the Transformers fallback entry.

    Each entry of ``architectures`` that is not already
    ``"TransformersForCausalLM"`` is checked for compatibility and, depending
    on ``model_config.model_impl``, rewritten to ``"TransformersForCausalLM"``.

    Note that ``architectures`` is modified in place and the same list object
    is returned; entries that were already rewritten are skipped.

    Args:
        model_config: Supplies ``hf_config`` (for its optional ``auto_map``),
            ``model``, ``revision`` and ``model_impl``.
        architectures: Architecture names to resolve.

    Returns:
        The (mutated) ``architectures`` list.

    Raises:
        ValueError: If ``model_impl`` is ``TRANSFORMERS`` or ``AUTO`` and the
            Transformers implementation of an architecture is not compatible.
    """
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue
        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()
        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        # Sorting by key produces that order ("AutoConfig" sorts first).
        auto_modules = {
            name:
            get_class_from_dynamic_module(module,
                                          model_config.model,
                                          revision=model_config.revision)
            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
        }
        # None when the checkpoint ships no custom "AutoModel" class; the
        # compatibility check then falls back to looking up `arch` itself.
        custom_model_module = auto_modules.get("AutoModel")
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch, custom_model_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersForCausalLM"
        if model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch, custom_model_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM. Try setting "
                    "VLLM_USE_V1=0.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersForCausalLM"
    return architectures


87
def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    Args:
        model_config: Supplies ``hf_config.architectures``, ``quantization``,
            ``model_impl`` and ``task``.

    Returns:
        A ``(model_cls, arch)`` pair as returned by
        ``ModelRegistry.resolve_model_cls``. ``model_cls`` is additionally
        wrapped by the matching adapter when ``task`` is ``"embed"``,
        ``"classify"`` or ``"reward"``.
    """
    # The attribute may be absent or explicitly None; treat both as "no
    # architectures" so the membership tests below do not raise TypeError.
    architectures = getattr(model_config.hf_config, "architectures",
                            None) or []

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    vllm_not_supported = not any(arch in vllm_supported_archs
                                 for arch in architectures)
    # Use the Transformers fallback when it is requested explicitly, or when
    # the vLLM implementation is not forced and none of the architectures is
    # supported by vLLM.
    if (model_config.model_impl == ModelImpl.TRANSFORMERS
            or (model_config.model_impl != ModelImpl.VLLM
                and vllm_not_supported)):
        architectures = resolve_transformers_arch(model_config, architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
118
119
120
121


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name resolved for ``model_config``."""
    _, arch_name = get_model_architecture(model_config)
    return arch_name
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Packed parameter name -> ordered list of its sub-parameter names.
    packed_mapping: Dict[str, List[str]]
    # Sub-parameter name -> (packed parameter name, index within the packed
    # list). Filled in by __post_init__ from `packed_mapping`.
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        """Populate ``inverse_packed_mapping`` from ``packed_mapping``."""
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Find the packed entry whose key is a suffix of ``module_name``.

        Returns:
            ``(key, sub_params)`` for the first key in ``packed_mapping``
            that ``module_name`` ends with, or ``None`` if there is none.
        """
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171


def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Share the model class's packed_modules_mapping with quant_config so
    that quant_config can properly match fused modules.

    The mapping object itself is assigned (no copy is made), so updates made
    to it by model_class.__new__ (ex. chatglm, qwen) are visible through
    quant_config as well. A warning is logged when the model class does not
    define the attribute.
    """
    mapping = getattr(model_class, "packed_modules_mapping", None)
    if mapping is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return
    # Assign the same object (by reference) rather than a copy.
    quant_config.packed_modules_mapping = mapping