utils.py 9.15 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Utilities for selecting and loading models."""
import contextlib
4
from dataclasses import dataclass, field
5
from typing import Dict, List, Optional, Tuple, Type
6

zhuwenwen's avatar
zhuwenwen committed
7
import os
8
import torch
9
import transformers
10
from torch import nn
11
from transformers.dynamic_module_utils import get_class_from_dynamic_module
12

13
14
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
17
from vllm.model_executor.models import ModelRegistry
18
19
20
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
21

22
23
logger = init_logger(__name__)

24
25
26
27
28
29
30
31
32
33

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's global default dtype.

    The previous default (from ``torch.get_default_dtype()``) is restored when
    the ``with`` block exits, including when the block raises. Without the
    ``finally`` an exception would leave the process-wide default changed.

    Args:
        dtype: dtype passed to ``torch.set_default_dtype`` for the duration
            of the context.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old_dtype)


34
35
def resolve_transformers_arch(model_config: ModelConfig,
                              architectures: list[str]):
    """Route architectures to the Transformers backend where appropriate.

    For every entry other than ``"TransformersForCausalLM"``, the model class
    is looked up in the ``transformers`` package. If it is not found there,
    the ``AutoModel`` entry of the HF config's ``auto_map`` is used instead.
    Depending on ``model_config.model_impl`` (TRANSFORMERS or AUTO), the
    entry is then replaced with ``"TransformersForCausalLM"``.

    Args:
        model_config: Model configuration; its ``hf_config``, ``model``,
            ``revision`` and ``model_impl`` attributes are read.
        architectures: Architecture names. This list is mutated in place.

    Returns:
        The same ``architectures`` list object that was passed in.

    Raises:
        ValueError: if no model class can be found for an architecture, or
            the found class reports ``is_backend_compatible() == False``
            under the TRANSFORMERS or AUTO implementation.
    """
    fallback_arch = "TransformersForCausalLM"
    for position, arch in enumerate(architectures):
        if arch == fallback_arch:
            continue

        hf_auto_map: dict[str, str] = getattr(model_config.hf_config,
                                              "auto_map", None) or dict()
        # Load the dynamic classes in sorted key order. The config class must
        # be initialized before any model class, otherwise the model class
        # cannot access it; sorting places "AutoConfig" ahead of
        # "AutoModel"/"AutoModelFor<Task>". Expected auto_map layout:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        dynamic_classes = {}
        for auto_name, class_ref in sorted(hf_auto_map.items(),
                                           key=lambda item: item[0]):
            dynamic_classes[auto_name] = get_class_from_dynamic_module(
                class_ref,
                model_config.model,
                revision=model_config.revision)

        model_module = getattr(transformers, arch, None)
        if model_module is None:
            if "AutoModel" not in hf_auto_map:
                raise ValueError(
                    f"Cannot find model module. '{arch}' is not a registered "
                    "model in the Transformers library (only relevant if the "
                    "model is meant to be in Transformers) and 'AutoModel' is "
                    "not present in the model config's 'auto_map' (relevant "
                    "if the model is custom).")
            model_module = dynamic_classes["AutoModel"]

        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        impl = model_config.model_impl
        if impl == ModelImpl.TRANSFORMERS:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[position] = fallback_arch
        elif impl == ModelImpl.AUTO:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM. Try setting "
                    "VLLM_USE_V1=0.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[position] = fallback_arch
    return architectures


88
89
90
# Architectures for which ``_configure_nn_env`` derives the NN-kernel and
# padding environment switches; for any other architecture those switches
# are forced to '0'.
_NN_SUPPORTED_ARCHITECTURES = (
    'LlamaForCausalLM', 'Qwen2ForCausalLM', 'QWenLMHeadModel',
    'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration',
    'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
    'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel',
    'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM',
    'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel',
)


def _configure_nn_env(hf_config, architectures: list[str]) -> None:
    """Export the NN-kernel / padding switches for this model via os.environ.

    Reads and writes LLAMA_NN, LM_NN, GEMM_PAD, FA_PAD, AWQ_MOE_SZ and
    AWQ_PAD. When no architecture is in ``_NN_SUPPORTED_ARCHITECTURES``,
    LLAMA_NN, LM_NN, GEMM_PAD, FA_PAD and AWQ_PAD are all set to '0'.

    Args:
        hf_config: the model's HF config; only its optional ``visual`` and
            ``vision_config`` attributes are read.
        architectures: architecture names taken from the HF config.
    """
    if not any(arch in architectures for arch in _NN_SUPPORTED_ARCHITECTURES):
        for name in ('LLAMA_NN', 'LM_NN', 'GEMM_PAD', 'FA_PAD', 'AWQ_PAD'):
            os.environ[name] = '0'
        return

    # A vision sub-config may live under ``visual`` or ``vision_config``.
    # Use truthiness so that a present-but-None/empty value counts as
    # "no vision" (a ``!= []`` comparison would treat None as vision).
    visions = (getattr(hf_config, "visual", [])
               or getattr(hf_config, "vision_config", []))
    has_vision = bool(visions)

    # LLAMA_NN: an explicit '0' from the user is kept; otherwise it is
    # disabled only for single-arch QWen/ChatGLM models with a vision part.
    if os.getenv('LLAMA_NN') != '0':
        vision_sensitive = (architectures == ['QWenLMHeadModel']
                            or architectures == ['ChatGLMModel'])
        os.environ['LLAMA_NN'] = '0' if (vision_sensitive
                                         and has_vision) else '1'

    # LM_NN: always off for single-arch Bloom/Falcon or when the user
    # exported LM_NN=0; on otherwise.
    lm_nn_disabled = (architectures == ['BloomForCausalLM']
                      or architectures == ['FalconForCausalLM']
                      or os.getenv('LM_NN') == '0')
    os.environ['LM_NN'] = '0' if lm_nn_disabled else '1'

    # GEMM/FA padding stay off unless the user explicitly enabled them.
    if os.getenv('GEMM_PAD') != '1':
        os.environ['GEMM_PAD'] = '0'
    if os.getenv('FA_PAD') != '1':
        os.environ['FA_PAD'] = '0'

    # AWQ-related configuration.
    if os.getenv('AWQ_MOE_SZ') is None:
        os.environ['AWQ_MOE_SZ'] = '1'
    if os.getenv('AWQ_PAD') is None:
        # The device is queried only when the user did not set AWQ_PAD.
        try:
            sm_count = torch.cuda.get_device_properties(
                torch.cuda.current_device()).multi_processor_count
        except Exception:
            # Best effort: if the device query fails, default padding on.
            logger.debug(
                "Could not query CUDA device properties; defaulting "
                "AWQ_PAD=1", exc_info=True)
            os.environ['AWQ_PAD'] = '1'
        else:
            # NOTE(review): 120 SMs presumably identifies a specific
            # accelerator that benefits from AWQ padding — TODO confirm.
            if sm_count == 120:
                os.environ['AWQ_PAD'] = '1'


def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    As a side effect, this exports the NN-kernel/padding environment
    switches through ``_configure_nn_env``.

    Returns:
        ``(model_cls, arch)`` from ``ModelRegistry.resolve_model_cls``.
        ``model_cls`` is wrapped by the embedding, classification or reward
        adapter when ``model_config.task`` is "embed", "classify" or
        "reward" respectively.
    """
    # HF configs may carry ``architectures=None``; treat that as empty
    # instead of failing with a TypeError on the membership tests below.
    architectures = getattr(model_config.hf_config, "architectures",
                            []) or []

    _configure_nn_env(model_config.hf_config, architectures)

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    vllm_not_supported = not any(arch in vllm_supported_archs
                                 for arch in architectures)
    if (model_config.model_impl == ModelImpl.TRANSFORMERS
            or (model_config.model_impl != ModelImpl.VLLM
                and vllm_not_supported)):
        architectures = resolve_transformers_arch(model_config, architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
153
154
155
156


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the resolved architecture name for ``model_config``."""
    _, arch_name = get_model_architecture(model_config)
    return arch_name
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179


@dataclass
class ParamMapping:
    """
    Bidirectional name lookup between packed (fused) parameters and the
    parameters packed into them, used during model weight loading.

    ``packed_mapping`` maps a packed name to its ordered constituents.
    ``__post_init__`` fills ``inverse_packed_mapping``, which maps each
    constituent name to ``(packed_name, index_within_pack)``.
    """
    packed_mapping: Dict[str, List[str]]
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        # Entries that only map a name to itself (e.g. {"W_pack": ["W_pack"]})
        # are not real packs, so they get no inverse entry. Later entries
        # overwrite earlier ones for the same constituent name.
        self.inverse_packed_mapping.update(
            (member, (packed, position))
            for packed, members in self.packed_mapping.items()
            if not (len(members) == 1 and members[0] == packed)
            for position, member in enumerate(members))

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Return ``(packed_name, constituents)`` for the first packed name
        (in ``packed_mapping`` order) that ``module_name`` ends with, or
        ``None`` when nothing matches."""
        return next(
            ((packed, members)
             for packed, members in self.packed_mapping.items()
             if module_name.endswith(packed)),
            None,
        )
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206


def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Share the model class's ``packed_modules_mapping`` with ``quant_config``
    so that the quantization config can match fused modules.

    The mapping object itself is assigned, not copied. Later updates made by
    ``model_class.__new__`` (ex. chatglm, qwen) are therefore visible to
    ``quant_config``. If the class defines no mapping, a warning is logged
    and ``quant_config`` is left untouched.
    """
    mapping = getattr(model_class, "packed_modules_mapping", None)
    if mapping is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return
    # Assign by reference (no copy) so later mutations stay in sync.
    quant_config.packed_modules_mapping = mapping