utils.py 5.42 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
3
"""Utilities for selecting and loading models."""
import contextlib
4
from dataclasses import dataclass, field
5
from typing import Dict, List, Optional, Tuple, Type
6
7

import torch
8
import transformers
9
from torch import nn
10
from transformers.dynamic_module_utils import get_class_from_dynamic_module
11

12
13
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
14
from vllm.model_executor.models import ModelRegistry
15
16
17
from vllm.model_executor.models.adapters import (as_classification_model,
                                                 as_embedding_model,
                                                 as_reward_model)
18

19
20
logger = init_logger(__name__)

21
22
23
24
25
26
27
28
29
30

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set torch's default floating-point dtype.

    The previous default is restored when the context exits, including when
    the body of the ``with`` statement raises.

    Args:
        dtype: dtype passed to ``torch.set_default_dtype`` for the duration
            of the context.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without ``finally`` an exception in the with-body would leak the
        # temporary default dtype into all later tensor construction.
        torch.set_default_dtype(old_dtype)


31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
def is_transformers_impl_compatible(
        arch: str,
        module: Optional["transformers.PreTrainedModel"] = None) -> bool:
    """Report whether the Transformers implementation of ``arch`` can be
    used by vLLM's Transformers backend.

    Args:
        arch: Architecture name, looked up as an attribute of the
            ``transformers`` package when ``module`` is not given.
        module: Explicit model class to inspect (e.g. one loaded from a
            remote-code ``auto_map``). Takes precedence over ``arch``.

    Returns:
        ``False`` if no class can be found. Otherwise, the result of the
        class's ``is_backend_compatible()`` when it defines one, else its
        ``_supports_flex_attn`` flag (``False`` when absent).
    """
    mod = module or getattr(transformers, arch, None)
    if mod is None:
        return False
    # Check for the attribute that is actually called; previously the guard
    # tested ``supports_backend`` and then called ``is_backend_compatible``.
    backend_check = getattr(mod, "is_backend_compatible", None)
    if callable(backend_check):
        return bool(backend_check())
    # Older classes without the backend API: fall back to the flex-attention
    # capability flag, treating a missing flag as "not compatible".
    return bool(getattr(mod, "_supports_flex_attn", False))


def resolve_transformers_fallback(model_config: ModelConfig,
                                  architectures: list[str]):
    """Rewrite architectures that should run on the Transformers backend.

    Every entry other than ``"TransformersModel"`` is replaced, in place, by
    ``"TransformersModel"`` when ``model_config.model_impl`` is
    ``ModelImpl.TRANSFORMERS`` or ``ModelImpl.AUTO`` and the Transformers
    implementation passes ``is_transformers_impl_compatible``. For any other
    ``model_impl`` the list is left unchanged.

    Args:
        model_config: Config whose ``hf_config`` (for ``auto_map``),
            ``model`` and ``model_impl`` are read.
        architectures: Architecture names to resolve. Mutated in place.

    Returns:
        The same ``architectures`` list object.

    Raises:
        ValueError: If ``model_impl`` is TRANSFORMERS or AUTO and the
            Transformers implementation of an architecture is not
            compatible with vLLM.
    """
    # The remote-code class depends only on the config, not on the
    # architecture being inspected, so resolve it lazily and at most once
    # (``get_class_from_dynamic_module`` may fetch/import remote code).
    custom_module = None
    custom_module_resolved = False
    for i, arch in enumerate(architectures):
        if arch == "TransformersModel":
            continue
        if not custom_module_resolved:
            auto_map = getattr(model_config.hf_config, "auto_map", None)
            if auto_map is not None and "AutoModel" in auto_map:
                custom_module = get_class_from_dynamic_module(
                    auto_map["AutoModel"], model_config.model)
            custom_module_resolved = True
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersModel"
        elif model_config.model_impl == ModelImpl.AUTO:
            if not is_transformers_impl_compatible(arch, custom_module):
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersModel"
    return architectures


75
def get_model_architecture(
        model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    Steps:
      1. Read ``hf_config.architectures``.
      2. Swap in ``QuantMixtralForCausalLM`` for Mixtral with a quantization
         method outside ``mixtral_supported``.
      3. Fall back to the Transformers backend when no listed architecture
         is registered with vLLM, or when ``model_impl`` is TRANSFORMERS.
      4. Resolve the class through ``ModelRegistry``.
      5. Wrap the class with the task adapter for "embed", "classify" or
         "reward" tasks.

    Returns:
        ``(model_cls, arch)`` as returned by
        ``ModelRegistry.resolve_model_cls``, with ``model_cls`` possibly
        wrapped by a task adapter.
    """
    # HF configs may carry ``architectures=None``; normalise to an empty list
    # so the membership tests below do not raise TypeError.
    # NOTE: when non-empty, this is hf_config's own list object, and
    # resolve_transformers_fallback rewrites it in place, so that rewrite is
    # visible on hf_config.architectures afterwards.
    architectures = getattr(model_config.hf_config, "architectures",
                            None) or []

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
    ]

    if (model_config.quantization is not None
            and model_config.quantization not in mixtral_supported
            and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    vllm_supported_archs = ModelRegistry.get_supported_archs()
    is_vllm_supported = any(arch in vllm_supported_archs
                            for arch in architectures)
    if (not is_vllm_supported
            or model_config.model_impl == ModelImpl.TRANSFORMERS):
        architectures = resolve_transformers_fallback(model_config,
                                                      architectures)

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
    if model_config.task == "embed":
        model_cls = as_embedding_model(model_cls)
    elif model_config.task == "classify":
        model_cls = as_classification_model(model_cls)
    elif model_config.task == "reward":
        model_cls = as_reward_model(model_cls)

    return model_cls, arch
107
108
109
110


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return just the resolved architecture name for ``model_config``.

    Delegates to ``get_model_architecture`` and discards the model class.
    """
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # Packed parameter name -> ordered list of the sub-parameter names it
    # fuses together.
    packed_mapping: Dict[str, List[str]]
    # Sub-parameter name -> (packed parameter name, position within the
    # pack). Filled in by __post_init__ from ``packed_mapping``.
    inverse_packed_mapping: Dict[str, Tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        """Build ``inverse_packed_mapping`` from ``packed_mapping``."""
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self,
                        module_name: str) -> Optional[Tuple[str, List[str]]]:
        """Find the packed entry that ``module_name`` refers to.

        Returns:
            ``(packed_name, sub_params)`` for the first key in
            ``packed_mapping`` (in insertion order) that ``module_name``
            ends with, or ``None`` if no key matches.
        """
        # NOTE(review): this is a plain string-suffix test with no "."
        # boundary, so a key that is a suffix of another module's name
        # would also match; confirm keys are chosen to avoid that.
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None