# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

from transformers import PretrainedConfig, WhisperConfig

from vllm.logger import init_logger

# Module-level logger; used for debug output once a config has been built.
logger = init_logger(__name__)


def adapt_config_dict(
    config_dict: dict[str, Any],
    defaults: dict[str, Any],
) -> PretrainedConfig:
    """Convert a Mistral-format config dict into a ``PretrainedConfig``.

    The dict is remapped in stages: general key renames, optional
    quantization remap, architecture selection, optional YaRN remap and
    optional vision/audio wrapping. ``defaults`` are then applied to any key
    that is still missing before the final config object is built.

    Args:
        config_dict: Mistral-format configuration. The remapping helpers
            modify it in place.
        defaults: Fallback values applied with ``setdefault`` after remapping.

    Returns:
        The ``PretrainedConfig`` built from the remapped dict.

    Raises:
        AssertionError: If the MoE-with-shared-experts branch is taken without
            a ``llama_4_scaling`` entry, if a ``llama_4_scaling`` entry lacks
            a required key, or if both a vision and an audio section are
            present.
    """
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    is_moe = bool(config_dict.get("moe"))
    # A MoE section with a positive number of shared experts selects the
    # MistralLarge3 architecture; a missing or null count is treated as 0.
    is_mistral_large_3 = (
        is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
    )
    if config_dict.get("model_type") == "mamba":
        config_dict["architectures"] = ["Mamba2ForCausalLM"]
    elif is_mistral_large_3:
        config_dict = _remap_moe_args(config_dict)
        config_dict["model_type"] = "deepseek_v3"
        config_dict["architectures"] = ["MistralLarge3ForCausalLM"]

        assert "llama_4_scaling" in config_dict, (
            "MistralLarge3 expect llama4 scaling config."
        )
        _check_llama_4_scaling_keys(config_dict["llama_4_scaling"])
    elif is_moe:
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    if config_dict.get("llama_4_scaling"):
        _check_llama_4_scaling_keys(config_dict["llama_4_scaling"])

    # Vision may be declared either under ``multimodal.vision_encoder_args``
    # or as a top-level ``vision_encoder`` section.
    is_vision = (config_dict.get("multimodal") or {}).get(
        "vision_encoder_args"
    ) or config_dict.get("vision_encoder")
    is_audio = bool(
        ((config_dict.get("multimodal") or {}).get("whisper_model_args") or {}).get(
            "encoder_args"
        )
    )

    assert not (is_vision and is_audio), "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    # Defaults never override a value that is already present.
    for k, v in defaults.items():
        config_dict.setdefault(k, v)

    config = PretrainedConfig.from_dict(config_dict)

    logger.debug("Initialized config %s", config)

    return config


def _check_llama_4_scaling_keys(llama_4_scaling: dict[str, Any]) -> None:
    """Assert that a ``llama_4_scaling`` section defines every required key."""
    llama_4_scaling_config_keys = ["original_max_position_embeddings", "beta"]
    assert all(key in llama_4_scaling for key in llama_4_scaling_config_keys), (
        "llama_4_scaling config should define the keys: "
        f"{','.join(llama_4_scaling_config_keys)}"
    )


def _remap_mistral_vision_args(config: dict) -> dict:
    """Wrap a text config and its vision section into a Pixtral config dict.

    The vision section is popped from ``config``: the ``multimodal`` entry
    when it is truthy, otherwise the ``vision_encoder`` entry. What remains of
    ``config`` becomes ``text_config``. A truthy ``quantization_config`` is
    copied to the top level of the returned dict.
    """
    # Read before the nested configs are built, as the value is re-exposed
    # at the top level of the wrapper.
    quantization = config.get("quantization_config")

    vision_key = "multimodal" if config.get("multimodal") else "vision_encoder"
    vision_section = config.pop(vision_key)

    remapped = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "vision_config": PretrainedConfig.from_dict(vision_section),
    }
    if quantization:
        remapped["quantization_config"] = quantization
    return remapped


def _remap_mistral_yarn_args(config: dict) -> dict:
110
111
112
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
113
114
        "beta": "beta_fast",
        "alpha": "beta_slow",
115
        "apply_scale": "apply_yarn_scaling",
116
117
    }
    yarn_config = config.get("yarn") or {}
118
    config["rope_parameters"] = {
119
        "rope_type": "yarn",
120
        "mscale_all_dim": 1,
121
    }
Julien Denize's avatar
Julien Denize committed
122
123
124
125

    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta

126
127
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
128
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
129
130
131

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
    return config


def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
150
        "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
151
152
153
154
155
156
157
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

158
    for new_key, (key, default_value) in top_level_mapping_with_default.items():
159
160
161
162
163
164
        config[new_key] = config.pop(key, default_value)

    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
165
166
167
168
169
170
171
172
173
174
175
176
177
178
    if config.get("quantization"):
        quantization = config.pop("quantization", {})
        if quantization.get("qformat_weight") == "fp8_e4m3":
            qscheme_act = quantization.get("qscheme_act")
            assert qscheme_act in ("NO_SCALES", "TENSOR", None), (
                "Only NO_SCALES and TENSOR (default) are supported for qscheme_act"
            )
            is_dynamic = qscheme_act == "NO_SCALES"
            config["quantization_config"] = {
                "quant_method": "fp8",
                "activation_scheme": "dynamic" if is_dynamic else "static",
            }
        else:
            raise ValueError(f"Found unknown quantization='{quantization}' in config")
179
180

    return config
Patrick von Platen's avatar
Patrick von Platen committed
181
182
183
184
185
186


def _remap_mistral_audio_args(config: dict) -> dict:
    """Wrap a text config and its Whisper-style section into a Voxtral config.

    ``whisper_model_args`` is popped from ``config["multimodal"]``. What
    remains of ``config`` becomes ``text_config``, and the encoder arguments
    are used to build a ``WhisperConfig`` stored as ``audio_config``. A truthy
    ``quantization_config`` is copied to the top level of the returned dict.

    Args:
        config: Remapped text config. Must contain
            ``multimodal.whisper_model_args`` and ``max_position_embeddings``.
            Modified in place: the whisper section is popped, and
            ``projection_size`` is added when the encoder is causal.

    Returns:
        A new dict with ``model_type`` ``"voxtral"``, the selected
        architecture, ``text_config`` and ``audio_config``.

    Raises:
        KeyError: If a required section or encoder argument is missing.
        NotImplementedError: If ``ragged_attention`` is set to a value that is
            not a string of digits.
    """
    whisper_args = config["multimodal"].pop("whisper_model_args")
    encoder_args = whisper_args["encoder_args"]
    downsample_args = whisper_args["downsample_args"]
    downsample_factor = downsample_args["downsample_factor"]
    is_causal = encoder_args.get("causal")

    # make sure that k/v blocks can be allocated with
    # unified k/v cache class and pool whisper k/v cache blocks
    # with downsample_factor:1 ratio
    if is_causal:
        block_pool_size = downsample_factor
        # Set before ``text_config`` is built below, so it ends up there.
        config["projection_size"] = downsample_factor * encoder_args["dim"]
    else:
        block_pool_size = 1

    # NOTE(review): ``.isdigit()`` requires a string, so ``ragged_attention``
    # is presumably a string in the Mistral format - confirm.
    _maybe_sliding_window = encoder_args.get("ragged_attention", None)
    if _maybe_sliding_window is None:
        sliding_window = None
    elif _maybe_sliding_window.isdigit():
        sliding_window = int(_maybe_sliding_window)
    else:
        raise NotImplementedError(f"Unsupported: {_maybe_sliding_window=}")

    architecture = (
        "VoxtralStreamingGeneration"
        if is_causal
        else "VoxtralForConditionalGeneration"
    )

    quant_config = config.get("quantization_config")
    config = {
        "model_type": "voxtral",
        "architectures": [architecture],
        "text_config": PretrainedConfig.from_dict(config),
        "audio_config": WhisperConfig(
            num_mel_bins=encoder_args["audio_encoding_args"]["num_mel_bins"],
            window_size=encoder_args["audio_encoding_args"]["window_size"],
            sampling_rate=encoder_args["audio_encoding_args"]["sampling_rate"],
            hop_length=encoder_args["audio_encoding_args"]["hop_length"],
            downsample_factor=downsample_factor,
            d_model=encoder_args["dim"],
            encoder_layers=encoder_args["n_layers"],
            encoder_ffn_dim=encoder_args["hidden_dim"],
            encoder_attention_heads=encoder_args["n_heads"],
            encoder_head_dim=encoder_args["head_dim"],
            vocab_size=encoder_args["vocab_size"],
            max_source_positions=encoder_args["max_source_positions"],
            is_encoder_decoder=False,  # Override WhisperConfig default
            is_causal=encoder_args.get("causal", False),
            sliding_window=sliding_window,
            block_pool_size=block_pool_size,
            pos_embed=encoder_args.get("pos_embed", "sinusoidal"),
            # only needed for RoPE
            max_position_embeddings=block_pool_size * config["max_position_embeddings"],
        ),
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config


def _remap_moe_args(config: dict) -> dict:
    moe_config_map = {
        "route_every_n": "moe_layer_freq",
        "first_k_dense_replace": "first_k_dense_replace",
        "num_experts_per_tok": "num_experts_per_tok",
        "num_experts": "n_routed_experts",
        "expert_hidden_dim": "moe_intermediate_size",
        "routed_scale": "routed_scaling_factor",
        "num_shared_experts": "n_shared_experts",
        "num_expert_groups": "n_group",
        "num_expert_groups_per_tok": "topk_group",
    }
    moe_config = config.get("moe", {})
    for old_name, new_name in moe_config_map.items():
        if old_name in moe_config:
            value = moe_config.pop(old_name)
            config[new_name] = value

    config["topk_method"] = None
    config["norm_topk_prob"] = True
    config["scoring_func"] = "softmax"

    return config