mistral.py 6.41 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

Patrick von Platen's avatar
Patrick von Platen committed
5
from transformers import PretrainedConfig, WhisperConfig
6
7
8
9
10
11

from vllm.logger import init_logger

# Module-level logger, created through vLLM's ``init_logger`` with this module's name.
logger = init_logger(__name__)


12
13
14
15
def adapt_config_dict(
    config_dict: dict[str, Any],
    defaults: dict[str, Any],
) -> PretrainedConfig:
    """Build a ``PretrainedConfig`` from a Mistral-format config dict.

    The dict is processed in this order:

    1. General Mistral keys are renamed via ``_remap_general_mistral_args``.
    2. A truthy ``"quantization"`` entry is translated via
       ``_remap_mistral_quantization_args``.
    3. ``"architectures"`` is set from ``"model_type"`` / ``"moe"``.
    4. A truthy ``"yarn"`` entry is translated via ``_remap_mistral_yarn_args``.
    5. A truthy ``"llama_4_scaling"`` entry is checked for its required keys.
    6. Vision or audio settings, when present, wrap the dict into a
       multimodal config dict.
    7. ``defaults`` fills in any top-level key that is still missing.

    Args:
        config_dict: Mistral-format configuration. The remap helpers modify
            it in place.
        defaults: Fallback values, applied only for keys absent from the
            final top-level dict.

    Returns:
        The configuration produced by ``PretrainedConfig.from_dict``.

    Raises:
        AssertionError: If ``"llama_4_scaling"`` lacks a required key, or if
            both vision and audio settings are present.
    """
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    # Pick the text architecture: mamba wins over MoE, dense is the fallback.
    if config_dict.get("model_type") == "mamba":
        architecture = "Mamba2ForCausalLM"
    elif config_dict.get("moe"):
        architecture = "MixtralForCausalLM"
    else:
        architecture = "MistralForCausalLM"
    config_dict["architectures"] = [architecture]

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    if config_dict.get("llama_4_scaling"):
        required_keys = ["original_max_position_embeddings", "beta"]
        scaling_config = config_dict["llama_4_scaling"]
        assert all(key in scaling_config for key in required_keys), (
            "llama_4_scaling config should define the keys: "
            f"{','.join(required_keys)}"
        )

    multimodal_args = config_dict.get("multimodal") or {}
    is_vision = multimodal_args.get("vision_encoder_args") or config_dict.get(
        "vision_encoder"
    )
    whisper_args = multimodal_args.get("whisper_model_args") or {}
    is_audio = bool(whisper_args.get("encoder_args"))

    assert not is_vision or not is_audio, "Vision and audio are mutually exclusive"

    # Two independent checks (not ``elif``), mirroring the flags computed above.
    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    # Defaults never override a value that is already set.
    for default_key, default_value in defaults.items():
        config_dict.setdefault(default_key, default_value)

    adapted = PretrainedConfig.from_dict(config_dict)
    logger.debug("Initialized config %s", adapted)
    return adapted


def _remap_mistral_vision_args(config: dict) -> dict:
    """Wrap a text config and its vision settings into a Pixtral config dict.

    The vision settings are removed from ``config``: from ``"multimodal"``
    when that entry is truthy, otherwise from ``"vision_encoder"``. What is
    left of ``config`` becomes the text config.

    Args:
        config: Config dict holding the vision settings; modified in place.

    Returns:
        A new dict with ``model_type``, ``architectures``, ``text_config`` and
        ``vision_config``, plus ``quantization_config`` when ``config`` has a
        truthy one.

    Raises:
        KeyError: If ``"multimodal"`` is falsy and ``"vision_encoder"`` is
            missing.
    """
    vision_key = "multimodal" if config.get("multimodal") else "vision_encoder"
    vision_args = config.pop(vision_key)

    quantization_config = config.get("quantization_config")

    remapped = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
    }
    # Text config is built first, then the vision config.
    remapped["text_config"] = PretrainedConfig.from_dict(config)
    remapped["vision_config"] = PretrainedConfig.from_dict(vision_args)

    # Surface the quantization settings on the top-level config as well.
    if quantization_config:
        remapped["quantization_config"] = quantization_config
    return remapped


def _remap_mistral_yarn_args(config: dict) -> dict:
88
89
90
    yarn_config_map = {
        "factor": "factor",
        "original_max_position_embeddings": "original_max_position_embeddings",
91
92
        "beta": "beta_fast",
        "alpha": "beta_slow",
93
        "apply_scale": "apply_yarn_scaling",
94
95
    }
    yarn_config = config.get("yarn") or {}
96
    config["rope_parameters"] = {
97
        "rope_type": "yarn",
98
        "mscale_all_dim": 1,
99
    }
Julien Denize's avatar
Julien Denize committed
100
101
102
103

    if rope_theta := config.pop("rope_theta", None):
        config["rope_parameters"]["rope_theta"] = rope_theta

104
105
    for old_name, new_name in yarn_config_map.items():
        if old_name in yarn_config:
106
            config["rope_parameters"][new_name] = yarn_config.pop(old_name)
107
108
109

    assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"

110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
    return config


def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
128
        "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
129
130
131
132
133
134
135
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

136
    for new_key, (key, default_value) in top_level_mapping_with_default.items():
137
138
139
140
141
142
143
144
145
        config[new_key] = config.pop(key, default_value)

    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
    quantization = config.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # This maps to the FP8 static per-tensor quantization scheme
146
        quantization_config = {"quant_method": "fp8", "activation_scheme": "static"}
147
148
149
150
    elif quantization.get("quant_method") == "compressed-tensors":
        # Pass through the quantization config to compressed-tensors
        quantization_config = quantization
    else:
151
        raise ValueError(f"Found unknown quantization='{quantization}' in config")
152
153
154
155

    config["quantization_config"] = quantization_config

    return config
Patrick von Platen's avatar
Patrick von Platen committed
156
157
158
159
160
161
162
163
164


def _remap_mistral_audio_args(config: dict) -> dict:
    """Wrap a text config and its Whisper settings into a Voxtral config dict.

    ``"whisper_model_args"`` is removed from ``config["multimodal"]``; its
    ``"encoder_args"`` and ``"downsample_args"`` feed a ``WhisperConfig``.
    What is left of ``config`` becomes the text config.

    Args:
        config: Config dict holding ``multimodal.whisper_model_args``;
            modified in place.

    Returns:
        A new dict with ``model_type``, ``architectures``, ``text_config`` and
        ``audio_config``, plus ``quantization_config`` when ``config`` has a
        truthy one.

    Raises:
        KeyError: If any of the required multimodal / whisper keys is missing.
    """
    whisper_model_args = config["multimodal"].pop("whisper_model_args")
    encoder = whisper_model_args["encoder_args"]
    downsample = whisper_model_args["downsample_args"]

    quantization_config = config.get("quantization_config")

    # Text config is built before any audio setting is read.
    text_config = PretrainedConfig.from_dict(config)

    audio_encoding = encoder["audio_encoding_args"]
    audio_config = WhisperConfig(
        num_mel_bins=audio_encoding["num_mel_bins"],
        window_size=audio_encoding["window_size"],
        sampling_rate=audio_encoding["sampling_rate"],
        hop_length=audio_encoding["hop_length"],
        downsample_factor=downsample["downsample_factor"],
        d_model=encoder["dim"],
        encoder_layers=encoder["n_layers"],
        encoder_ffn_dim=encoder["hidden_dim"],
        encoder_attention_heads=encoder["n_heads"],
        vocab_size=encoder["vocab_size"],
        max_source_positions=encoder["max_source_positions"],
        is_encoder_decoder=False,  # Override WhisperConfig default
    )

    remapped = {
        "model_type": "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config": text_config,
        "audio_config": audio_config,
    }
    # Surface the quantization settings on the top-level config as well.
    if quantization_config:
        remapped["quantization_config"] = quantization_config
    return remapped