utils.py 8.52 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Utilities for selecting and loading models."""
import contextlib
4
from dataclasses import dataclass, field
5
from typing import Dict, List, Optional, Tuple, Type
6

zhuwenwen's avatar
zhuwenwen committed
7
import os
8
import torch
9
import transformers
10
from torch import nn
11
from transformers.dynamic_module_utils import get_class_from_dynamic_module
12

13
14
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
15
16
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
17
from vllm.model_executor.models import ModelRegistry
18
19
20
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
21

22
23
logger = init_logger(__name__)

24
25
26
27
28
29
30
31
32
33

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily switch torch's default dtype to ``dtype``.

    The dtype that was the default on entry is restored when the ``with``
    block exits, including when the block exits by raising an exception.

    Args:
        dtype: The dtype to install as the default for the duration of the
            ``with`` block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without the finally clause an exception raised inside the `with`
        # body would leave the process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def is_transformers_impl_compatible(
        arch: str,
        module: Optional[Type["transformers.PreTrainedModel"]] = None
) -> bool:
    """Report whether a Transformers model class is usable by the fallback.

    Args:
        arch: Architecture name. When ``module`` is not given, the class is
            looked up as an attribute of the ``transformers`` package.
        module: Optional model class to inspect instead of doing the lookup.

    Returns:
        ``False`` when no class can be found. Otherwise, the result of the
        class's ``is_backend_compatible()`` when it exposes a
        ``supports_backend`` attribute, else the truth value of its
        ``_supports_flex_attn`` attribute (``False`` when that attribute is
        absent).
    """
    # Explicit None check: a supplied class should never be discarded just
    # because it happens to evaluate as falsy.
    mod = module if module is not None else getattr(transformers, arch, None)
    if mod is None:
        return False
    # NOTE(review): the attribute probed (`supports_backend`) differs from
    # the method called (`is_backend_compatible`) - confirm this is intended.
    if hasattr(mod, "supports_backend"):
        return mod.is_backend_compatible()
    # A class that does not define the flag at all is reported as
    # incompatible instead of raising AttributeError.
    return bool(getattr(mod, "_supports_flex_attn", False))


def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    """Redirect architectures to the generic "TransformersModel" entry.

    Every entry of ``architectures`` other than "TransformersModel" is
    examined. When ``model_config.model_impl`` is ``ModelImpl.TRANSFORMERS``
    or ``ModelImpl.AUTO`` and the Transformers implementation is compatible,
    the entry is overwritten in place with "TransformersModel"; for any other
    ``model_impl`` the entry is left as it is.

    Args:
        model_config: Supplies ``hf_config``, ``model`` and ``model_impl``.
        architectures: Architecture names; modified in place.

    Returns:
        The same ``architectures`` list object.

    Raises:
        ValueError: If the Transformers implementation of an examined
            architecture is not compatible with vLLM.
    """
    for idx, arch_name in enumerate(architectures):
        if arch_name == "TransformersModel":
            continue

        # When the HF config's auto_map has an "AutoModel" entry, load that
        # class and use it for the compatibility check.
        hf_auto_map = getattr(model_config.hf_config, "auto_map", None)
        dynamic_cls = None
        if hf_auto_map is not None and "AutoModel" in hf_auto_map:
            dynamic_cls = get_class_from_dynamic_module(
                hf_auto_map["AutoModel"], model_config.model)

        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch_name, dynamic_cls):
                raise ValueError(
                    f"The Transformers implementation of {arch_name} is not "
                    "compatible with vLLM.")
            architectures[idx] = "TransformersModel"
        elif model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch_name, dynamic_cls):
                raise ValueError(
                    f"{arch_name} has no vLLM implementation and the "
                    "Transformers implementation is not compatible with vLLM."
                )
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch_name)
            architectures[idx] = "TransformersModel"

    return architectures


78
79
80
# Architectures for which the LLAMA_NN / LM_NN / GEMM_PAD / FA_PAD / AWQ_PAD
# environment switches are configured from the current environment; for any
# other architecture those five switches are forced to "0".
# TODO: support deepseek distillation series models
# ('LlamaForCausalLM', 'Qwen2ForCausalLM')
_NN_SUPPORTED_ARCHITECTURES = (
    "LlamaForCausalLM",
    "Qwen2ForCausalLM",
    "QWenLMHeadModel",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
    "Qwen2MoeForCausalLM",
    "ChatGLMModel",
    "ChatGLMForConditionalGeneration",
    "BaichuanForCausalLM",
    "BloomForCausalLM",
    "MedusaModel",
    "MixtralForCausalLM",
    "MLPSpeculatorPreTrainedModel",
    "FalconForCausalLM",
    "DeepseekV2ForCausalLM",
    "DeepseekV3ForCausalLM",
    "DeepSeekMTPModel",
)


def _configure_nn_env_flags(architectures: List[str], visions) -> None:
    """Set the LLAMA_NN, LM_NN, GEMM_PAD, FA_PAD and AWQ_* env variables.

    Args:
        architectures: The ``architectures`` value read from the HF config.
        visions: The ``visual`` / ``vision_config`` value read from the HF
            config, or ``[]`` when neither is present.
    """
    if not any(arch in architectures
               for arch in _NN_SUPPORTED_ARCHITECTURES):
        for name in ("LLAMA_NN", "LM_NN", "GEMM_PAD", "FA_PAD", "AWQ_PAD"):
            os.environ[name] = "0"
        return

    # An explicit LLAMA_NN=0 from the user is left untouched.
    if os.getenv("LLAMA_NN") != "0":
        # QWen / ChatGLM configs that carry a vision section get LLAMA_NN=0.
        has_vision_section = (architectures
                              in (["QWenLMHeadModel"], ["ChatGLMModel"])
                              and visions != [])
        os.environ["LLAMA_NN"] = "0" if has_vision_section else "1"

    lm_nn_off = (architectures
                 in (["BloomForCausalLM"], ["FalconForCausalLM"])
                 or os.getenv("LM_NN") == "0")
    os.environ["LM_NN"] = "0" if lm_nn_off else "1"

    # Padding switches stay on only when the user set them to "1".
    if os.getenv("GEMM_PAD") != "1":
        os.environ["GEMM_PAD"] = "0"
    if os.getenv("FA_PAD") != "1":
        os.environ["FA_PAD"] = "0"

    # AWQ-related settings; values already present in the environment win.
    if os.getenv("AWQ_MOE_SZ") is None:
        os.environ["AWQ_MOE_SZ"] = "1"
    if os.getenv("AWQ_PAD") is None:
        try:
            sm_count = torch.cuda.get_device_properties(
                torch.cuda.current_device()).multi_processor_count
        except Exception:
            # Deliberate best effort: when the device cannot be queried,
            # AWQ_PAD is enabled.
            os.environ["AWQ_PAD"] = "1"
        else:
            if sm_count == 120:
                os.environ["AWQ_PAD"] = "1"


def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    As a side effect, the LLAMA_NN, LM_NN, GEMM_PAD, FA_PAD and AWQ_*
    environment variables of the current process are updated based on the
    architectures listed in the HF config.

    Args:
        model_config: Supplies ``hf_config``, ``quantization``,
            ``model_impl`` and ``task``.

    Returns:
        A ``(model_cls, arch)`` tuple as returned by
        ``ModelRegistry.resolve_model_cls``, with ``model_cls`` wrapped by
        the embedding / classification / reward adapter when
        ``model_config.task`` is "embed" / "classify" / "reward".
    """
    hf_config = model_config.hf_config
    architectures = getattr(hf_config, "architectures", [])
    visions = (getattr(hf_config, "visual", [])
               or getattr(hf_config, "vision_config", []))
    _configure_nn_env_flags(architectures, visions)

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    is_vllm_supported = any(arch in vllm_supported_archs
                            for arch in architectures)
    if (not is_vllm_supported
            or model_config.model_impl == ModelImpl.TRANSFORMERS):
        architectures = resolve_transformers_fallback(model_config,
                                                      architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
148
149
150
151


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return the architecture name that ``get_model_architecture`` resolves
    for ``model_config`` (the second element of its result)."""
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Packed parameter name -> ordered names of the parameters packed in it.
    packed_mapping: Dict[str, List[str]]
    # Constituent parameter name -> (packed name, index inside the packed
    # list). Filled in by __post_init__.
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Populate ``inverse_packed_mapping`` from ``packed_mapping``."""
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Find the packed entry whose key is a suffix of ``module_name``.

        Returns:
            ``(packed_name, sub_params)`` for the first key of
            ``packed_mapping`` (in dict order) that ``module_name`` ends
            with, or ``None`` when no key matches.
        """
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201


def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: Type[nn.Module]):
    """
    Share the model class's ``packed_modules_mapping`` with ``quant_config``
    so that quant_config can properly match fused modules.

    The mapping object itself is assigned (not copied), so updates made to
    it by model_class.__new__ (ex. chatglm, qwen) are seen by quant_config.
    When the model class defines no such mapping, a warning is logged and
    ``quant_config`` is left unchanged.
    """
    shared_mapping = getattr(model_class, "packed_modules_mapping", None)
    if shared_mapping is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return

    # Assign the very same object so later in-place updates are shared.
    quant_config.packed_modules_mapping = shared_mapping