"tests/entrypoints/openai/models/test_models.py" did not exist on "e1957c6ebdd4860f832c26ae4de4195d10803723"
mistral.py 6.2 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

Patrick von Platen's avatar
Patrick von Platen committed
5
from transformers import PretrainedConfig, WhisperConfig
6
7
8
9
10
11

from vllm.logger import init_logger

# Module-level logger; adapt_config_dict uses it to emit the resulting config
# at debug level.
logger = init_logger(__name__)


12
def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
    """Convert a Mistral-style config dict into a ``PretrainedConfig``.

    Top-level keys are renamed to their Hugging Face names, and the optional
    quantization / YaRN / vision / audio sections are translated by the
    ``_remap_*`` helpers. The result is then loaded with
    ``PretrainedConfig.from_dict``.

    Args:
        config_dict: Mistral-style configuration. It is NOT modified; a deep
            copy is remapped instead.
        **kwargs: Extra top-level entries merged into the copied config before
            any remapping. They override same-named keys of ``config_dict``.
            Their values are not copied.

    Returns:
        The resulting ``PretrainedConfig``. Vision and audio configs are
        wrapped into a composite holding ``text_config`` plus
        ``vision_config`` / ``audio_config``.

    Raises:
        AssertionError: if ``llama_4_scaling`` lacks a required key, if both a
            vision and an audio encoder are configured, or if the ``yarn``
            section has unrecognized keys.
        ValueError: if the ``quantization`` section is not recognized.
    """
    import copy

    # The helpers below pop/rename keys in place, including inside nested
    # dicts ("yarn", "multimodal"). Remapping a private copy keeps the
    # caller's dict intact, so repeated calls on the same input agree instead
    # of, e.g., resetting "activation" to its default or losing the vision
    # section.
    config_dict = copy.deepcopy(config_dict)
    config_dict.update(kwargs)
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    # Mixture-of-experts configs load as Mixtral, dense ones as Mistral.
    if config_dict.get("moe"):
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    if config_dict.get("llama_4_scaling"):
        llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
        assert all(
            key in config_dict["llama_4_scaling"]
            for key in llama_4_scaling_config_keys
        ), (
            "llama_4_scaling config should define the keys: "
            f"{','.join(llama_4_scaling_config_keys)}"
        )

    multimodal = config_dict.get("multimodal") or {}
    is_vision = multimodal.get("vision_encoder_args") or config_dict.get(
        "vision_encoder"
    )
    is_audio = bool((multimodal.get("whisper_model_args") or {}).get("encoder_args"))

    assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    config = PretrainedConfig.from_dict(config_dict)
    logger.debug("Initialized config %s", config)
    return config


def _remap_mistral_vision_args(config: dict) -> dict:
    """Wrap a Mistral vision config into a Pixtral composite config dict.

    The vision section is popped from ``config``. If ``"multimodal"`` is
    truthy, the whole ``"multimodal"`` dict is used; otherwise
    ``"vision_encoder"`` is popped, which raises KeyError when it is absent.
    The rest of ``config`` becomes the ``text_config``. Any
    ``quantization_config`` is copied to the top level of the result.

    NOTE(review): in the "multimodal" case the entire multimodal dict, not
    just its "vision_encoder_args" entry, becomes ``vision_config`` — confirm
    the consumer expects that shape.
    """
    source_key = "multimodal" if config.get("multimodal") else "vision_encoder"
    vision_section = config.pop(source_key)

    # Read before from_dict() sees the text dict.
    quant_config = config.get("quantization_config")
    text_config = PretrainedConfig.from_dict(config)
    vision_config = PretrainedConfig.from_dict(vision_section)

    remapped = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": text_config,
        "vision_config": vision_config,
    }
    if quant_config:
        remapped["quantization_config"] = quant_config
    return remapped


def _remap_mistral_yarn_args(config: dict) -> dict:
81
82
83
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
84
85
        "beta": "beta_fast",
        "alpha": "beta_slow",
86
        "apply_scale": "apply_yarn_scaling",
87
88
    }
    yarn_config = config.get("yarn") or {}
89
    config["rope_parameters"] = {
90
        "rope_type": "yarn",
91
        "mscale_all_dim": 1,
92
    }
Julien Denize's avatar
Julien Denize committed
93
94
95
96

    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta

97
98
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
99
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
100
101
102

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
    return config


def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": ("max_seq_len", 128_000),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

129
    for new_key, (key, default_value) in top_level_mapping_with_default.items():
130
131
132
133
134
135
136
137
138
        config[new_key] = config.pop(key, default_value)

    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
    quantization = config.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # This maps to the FP8 static per-tensor quantization scheme
139
        quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
140
141
142
143
    elif quantization.get("quant_method") == "compressed-tensors":
        # Pass through the quantization config to compressed-tensors
        quantization_config = quantization
    else:
144
        raise ValueError(f"Found unknown quantization='{quantization}' in config")
145
146
147
148

    config["quantization_config"] = quantization_config

    return config
Patrick von Platen's avatar
Patrick von Platen committed
149
150
151
152
153
154
155
156
157


def _remap_mistral_audio_args(config: dict) -> dict:
    """Wrap a Mistral audio config into a Voxtral composite config dict.

    ``whisper_model_args`` is popped from ``config["multimodal"]``. Its
    ``encoder_args`` / ``downsample_args`` build a ``WhisperConfig``, and the
    remaining ``config`` becomes the ``text_config``. Any
    ``quantization_config`` is copied to the top level of the result.

    Raises:
        KeyError: if any expected whisper/encoder/downsample key is missing.
    """
    whisper_section = config["multimodal"].pop("whisper_model_args")
    encoder = whisper_section["encoder_args"]
    downsample = whisper_section["downsample_args"]

    quant_config = config.get("quantization_config")
    text_config = PretrainedConfig.from_dict(config)

    audio_encoding = encoder["audio_encoding_args"]
    audio_config = WhisperConfig(
        num_mel_bins=audio_encoding["num_mel_bins"],
        window_size=audio_encoding["window_size"],
        sampling_rate=audio_encoding["sampling_rate"],
        hop_length=audio_encoding["hop_length"],
        downsample_factor=downsample["downsample_factor"],
        d_model=encoder["dim"],
        encoder_layers=encoder["n_layers"],
        encoder_ffn_dim=encoder["hidden_dim"],
        encoder_attention_heads=encoder["n_heads"],
        vocab_size=encoder["vocab_size"],
        max_source_positions=encoder["max_source_positions"],
        is_encoder_decoder=False,  # Override WhisperConfig default
    )

    remapped = {
        "model_type": "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config": text_config,
        "audio_config": audio_config,
    }
    if quant_config:
        remapped["quantization_config"] = quant_config
    return remapped