mistral.py 5.62 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any

Patrick von Platen's avatar
Patrick von Platen committed
5
from transformers import PretrainedConfig, WhisperConfig
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

from vllm.logger import init_logger

logger = init_logger(__name__)


def adapt_config_dict(config_dict: dict[str, Any],
                      **kwargs) -> PretrainedConfig:
    """Convert a Mistral-format config dict into a ``PretrainedConfig``.

    The dict is first merged with ``kwargs`` and its Mistral keys are renamed
    to their Hugging Face equivalents. Optional sections (``quantization``,
    ``moe``, ``yarn``, vision and audio encoder args) are then remapped by
    the dedicated helpers before the final config object is built.

    Args:
        config_dict: Mistral-format configuration. NOTE: it is updated in
            place (``kwargs`` are merged in and keys are renamed).
        **kwargs: Extra entries merged into ``config_dict``; they override
            existing keys of the same name.

    Returns:
        The ``PretrainedConfig`` built from the remapped dictionary.

    Raises:
        AssertionError: If both vision and audio encoder arguments are
            present in the config.
        ValueError: Propagated from the quantization remapping when the
            ``quantization`` section is not recognised.
    """
    config_dict.update(kwargs)
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    # A truthy "moe" section selects the Mixtral architecture.
    if config_dict.get("moe"):
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    # "multimodal" may be missing or explicitly None; treat both as empty.
    multimodal = config_dict.get("multimodal") or {}
    is_vision = bool(
        multimodal.get("vision_encoder_args")
        or config_dict.get("vision_encoder"))
    is_audio = bool(
        (multimodal.get("whisper_model_args") or {}).get("encoder_args"))

    assert not (is_vision and is_audio), \
        "Vision and audio are mutually exclusive"

    if is_vision:
        config_dict = _remap_mistral_vision_args(config_dict)
    if is_audio:
        config_dict = _remap_mistral_audio_args(config_dict)

    config = PretrainedConfig.from_dict(config_dict)

    logger.debug("Initialized config %s", config)

    return config


def _remap_mistral_vision_args(config: dict) -> dict:
    """Wrap a text config and its vision section into a Pixtral config dict.

    The vision section is popped from ``config`` (so ``config`` is mutated):
    it is taken from ``"multimodal"`` when that entry is truthy, otherwise
    from ``"vision_encoder"``. What remains of ``config`` becomes the
    ``text_config``. A truthy ``"quantization_config"`` entry is also copied
    to the top level of the returned dict.
    """
    if config.get("multimodal"):
        vision_key = "multimodal"
    else:
        vision_key = "vision_encoder"
    vision_args = config.pop(vision_key)

    # Read before wrapping so it can be re-exposed at the top level.
    quantization = config.get("quantization_config")

    remapped = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "vision_config": PretrainedConfig.from_dict(vision_args),
    }
    if quantization:
        remapped["quantization_config"] = quantization
    return remapped


def _remap_mistral_yarn_args(config: dict) -> dict:
    # Direct remaps: yarn.X -> rope_scaling.Y
    # Source keys are from mistral.model.args.YarnArgs
    _map = {
        "beta": "beta_fast",
        "alpha": "beta_slow",
    }
    yarn_config = config.get("yarn") or {}
    renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
    config["rope_scaling"] = {
        "rope_type": "yarn",
        "mscale_all_dim": 1,  # We hardcoded this to 1
        **renamed_yarn_config
    }
    return config


def _remap_general_mistral_args(config: dict) -> dict:
    # Mistral key -> HF key
    config_mapping = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }
    # HF key -> (Mistral key, default value)
    top_level_mapping_with_default = {
        "model_type": ("model_type", "transformer"),
        "hidden_act": ("activation", "silu"),
        "tie_word_embeddings": ("tied_embeddings", False),
        "max_seq_len": ("max_seq_len", 128_000),
        "max_position_embeddings": ("max_position_embeddings", 128_000),
    }

    for key, new_key in config_mapping.items():
        if key in config:
            config[new_key] = config.pop(key)

    for new_key, (key,
                  default_value) in top_level_mapping_with_default.items():
        config[new_key] = config.pop(key, default_value)

    return config


def _remap_mistral_quantization_args(config: dict) -> dict:
    quantization = config.get("quantization", {})
    if quantization.get("qformat_weight") == "fp8_e4m3":
        # This maps to the FP8 static per-tensor quantization scheme
        quantization_config = {
            "quant_method": "fp8",
            "activation_scheme": "static"
        }
    elif quantization.get("quant_method") == "compressed-tensors":
        # Pass through the quantization config to compressed-tensors
        quantization_config = quantization
    else:
        raise ValueError(
            f"Found unknown quantization='{quantization}' in config")

    config["quantization_config"] = quantization_config

    return config
Patrick von Platen's avatar
Patrick von Platen committed
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159


def _remap_mistral_audio_args(config: dict) -> dict:
    """Wrap a text config and its Whisper section into a Voxtral config dict.

    ``whisper_model_args`` is popped from ``config["multimodal"]`` (so the
    input is mutated); its ``encoder_args`` and ``downsample_args`` are used
    to build the ``audio_config``. What remains of ``config`` becomes the
    ``text_config``. A truthy ``"quantization_config"`` entry is also copied
    to the top level of the returned dict.

    Raises:
        KeyError: If any of the required audio arguments is missing.
    """
    whisper_args = config["multimodal"].pop("whisper_model_args")
    encoder_args = whisper_args["encoder_args"]
    downsample_args = whisper_args["downsample_args"]
    audio_encoding_args = encoder_args["audio_encoding_args"]

    # Read before wrapping so it can be re-exposed at the top level.
    quant_config = config.get("quantization_config")
    config = {
        "model_type":
        "whixtral",
        "architectures": ["VoxtralForConditionalGeneration"],
        "text_config":
        PretrainedConfig.from_dict(config),
        "audio_config":
        WhisperConfig(
            num_mel_bins=audio_encoding_args["num_mel_bins"],
            window_size=audio_encoding_args["window_size"],
            sampling_rate=audio_encoding_args["sampling_rate"],
            hop_length=audio_encoding_args["hop_length"],
            downsample_factor=downsample_args["downsample_factor"],
            d_model=encoder_args["dim"],
            encoder_layers=encoder_args["n_layers"],
            encoder_ffn_dim=encoder_args["hidden_dim"],
            encoder_attention_heads=encoder_args["n_heads"],
            vocab_size=encoder_args["vocab_size"],
            max_source_positions=encoder_args["max_source_positions"],
            is_encoder_decoder=False,  # Override WhisperConfig default
        )
    }
    if quant_config:
        config["quantization_config"] = quant_config
    return config