Unverified Commit 14601f5f authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Config] Refactor mistral configs (#20570)


Signed-off-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 042d131f
...@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"qscale_act": "input_scale", "qscale_act": "input_scale",
"qscale_weight": "weight_scale", "qscale_weight": "weight_scale",
"kv_fake_quantizer.qscale_act": "kv_scale", "kv_fake_quantizer.qscale_act": "kv_scale",
"q_fake_quantizer.qscale_act": "attn.q_scale",
"k_fake_quantizer.qscale_act": "k_scale",
"v_fake_quantizer.qscale_act": "v_scale",
"wq": "q_proj", "wq": "q_proj",
"wk": "k_proj", "wk": "k_proj",
"wv": "v_proj", "wv": "v_proj",
......
...@@ -7,7 +7,7 @@ import os ...@@ -7,7 +7,7 @@ import os
import time import time
from functools import cache, partial from functools import cache, partial
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union from typing import Any, Callable, Optional, TypeVar, Union
import huggingface_hub import huggingface_hub
from huggingface_hub import get_safetensors_metadata, hf_hub_download from huggingface_hub import get_safetensors_metadata, hf_hub_download
...@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config, ...@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
SkyworkR1VChatConfig, SolarConfig, SkyworkR1VChatConfig, SolarConfig,
Telechat2Config, UltravoxConfig) Telechat2Config, UltravoxConfig)
# yapf: enable # yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file from vllm.transformers_utils.utils import check_gguf_file
from vllm.utils import resolve_obj_by_qualname from vllm.utils import resolve_obj_by_qualname
...@@ -394,7 +395,16 @@ def get_config( ...@@ -394,7 +395,16 @@ def get_config(
config = _maybe_remap_hf_config_attrs(config) config = _maybe_remap_hf_config_attrs(config)
elif config_format == ConfigFormat.MISTRAL: elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, **kwargs) # This function loads a params.json config which
# should be used when loading models in mistral format
config_dict = _download_mistral_config_file(model, revision)
if (max_position_embeddings :=
config_dict.get("max_position_embeddings")) is None:
max_position_embeddings = _maybe_retrieve_max_pos_from_hf(
model, revision, **kwargs)
config_dict["max_position_embeddings"] = max_position_embeddings
config = adapt_config_dict(config_dict)
else: else:
supported_formats = [ supported_formats = [
fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
...@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None: ...@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
exc_info=e) exc_info=e)
def load_params_config(model: Union[str, Path], revision: Optional[str],
**kwargs) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
if config_dict is None:
raise ValueError(
f"Failed to load mistral '{config_file_name}' config for model "
f"{model}. Please check if the model is a mistral-format model "
f"and if the config file exists.")
assert isinstance(config_dict, dict)
config_mapping = {
"dim": "hidden_size",
"norm_eps": "rms_norm_eps",
"n_kv_heads": "num_key_value_heads",
"n_layers": "num_hidden_layers",
"n_heads": "num_attention_heads",
"hidden_dim": "intermediate_size",
}
def recurse_elems(elem: Any):
if isinstance(elem, dict):
config_dict = {}
for key, value in elem.items():
key = config_mapping.get(key, key)
config_dict[key] = recurse_elems(value)
return config_dict
else:
return elem
config_dict["model_type"] = config_dict.get("model_type", "transformer")
config_dict["hidden_act"] = config_dict.get("activation", "silu")
config_dict["tie_word_embeddings"] = config_dict.get(
"tie_embeddings", False)
if config_dict.get("max_position_embeddings") is None:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
config_dict["max_position_embeddings"] = max_position_embeddings
if config_dict.get("quantization") is not None:
quantization = config_dict.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
elif quantization.get("quant_method") == "compressed-tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config_dict["quantization_config"] = quantization_config
config_type: Literal["text",
"multimodal"] = "multimodal" if config_dict.get(
"vision_encoder") is not None else "text"
if config_dict.get("moe") is not None:
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if config_type == "multimodal":
multimodal_config = config_dict.pop("vision_encoder")
quantization_config = config_dict.get("quantization_config", {})
config_dict = {
"text_config": config_dict,
"vision_config": multimodal_config
}
config_dict["architectures"] = ["PixtralForConditionalGeneration"]
config_dict["model_type"] = "pixtral"
if quantization_config:
config_dict["quantization_config"] = quantization_config
config_dict.update(kwargs)
config_dict = recurse_elems(config_dict)
# transform to HF config format
if config_type == "multimodal":
config_dict["text_config"] = PretrainedConfig(
**config_dict["text_config"])
config_dict["vision_config"] = PretrainedConfig(
**config_dict["vision_config"])
return PretrainedConfig(**config_dict)
def get_hf_image_processor_config( def get_hf_image_processor_config(
model: Union[str, Path], model: Union[str, Path],
hf_token: Optional[Union[bool, str]] = None, hf_token: Optional[Union[bool, str]] = None,
...@@ -920,3 +819,35 @@ def try_get_tokenizer_config( ...@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
) )
except Exception: except Exception:
return None return None
def _download_mistral_config_file(model, revision) -> dict:
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)
if config_dict is None:
raise ValueError(
f"Failed to load mistral '{config_file_name}' config for model "
f"{model}. Please check if the model is a mistral-format model "
f"and if the config file exists.")
assert isinstance(config_dict, dict)
return config_dict
def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
max_position_embeddings = 128_000
try:
trust_remote_code_val = kwargs.get("trust_remote_code", False)
hf_config = get_config(model=model,
trust_remote_code=trust_remote_code_val,
revision=revision,
config_format=ConfigFormat.HF)
if hf_value := hf_config.get_text_config().max_position_embeddings:
max_position_embeddings = hf_value
except Exception as e:
logger.warning(
"The params.json file is missing 'max_position_embeddings'"
" and could not get a value from the HF config."
" Defaulting to 128000",
exc_info=e)
return max_position_embeddings
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers import PretrainedConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
def adapt_config_dict(config_dict: dict[str, Any],
**kwargs) -> PretrainedConfig:
config_dict.update(kwargs)
config_dict = _remap_general_mistral_args(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
if bool(config_dict.get("moe")):
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
if bool(config_dict.get("yarn")):
config_dict = _remap_mistral_yarn_args(config_dict)
if bool((config_dict.get("multimodal") or {}).get("vision_encoder_args")
or config_dict.get("vision_encoder")):
config_dict = _remap_mistral_vision_args(config_dict)
config = PretrainedConfig.from_dict(config_dict)
logger.debug("Initialized config", config)
return config
def _remap_mistral_vision_args(config: dict) -> dict:
if config.get("multimodal"):
vision_config = config.pop("multimodal")
else:
vision_config = config.pop("vision_encoder")
quant_config = config.get("quantization_config")
config = {
"model_type": "pixtral",
"architectures": ["PixtralForConditionalGeneration"],
"text_config": PretrainedConfig.from_dict(config),
"vision_config": PretrainedConfig.from_dict(vision_config),
}
if quant_config:
config["quantization_config"] = quant_config
return config
def _remap_mistral_yarn_args(config: dict) -> dict:
# Direct remaps: yarn.X -> rope_scaling.Y
# Source keys are from mistral.model.args.YarnArgs
_map = {
"beta": "beta_fast",
"alpha": "beta_slow",
}
yarn_config = config.get("yarn") or {}
renamed_yarn_config = {_map.get(k, k): v for k, v in yarn_config.items()}
config["rope_scaling"] = {
"rope_type": "yarn",
"mscale_all_dim": 1, # We hardcoded this to 1
**renamed_yarn_config
}
return config
def _remap_general_mistral_args(config: dict) -> dict:
# Mistral key -> HF key
config_mapping = {
"dim": "hidden_size",
"norm_eps": "rms_norm_eps",
"n_kv_heads": "num_key_value_heads",
"n_layers": "num_hidden_layers",
"n_heads": "num_attention_heads",
"hidden_dim": "intermediate_size",
}
# HF key -> (Mistral key, default value)
top_level_mapping_with_default = {
"model_type": ("model_type", "transformer"),
"hidden_act": ("activation", "silu"),
"tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", 128_000),
"max_position_embeddings": ("max_position_embeddings", 128_000),
}
for key, new_key in config_mapping.items():
if key in config:
config[new_key] = config.pop(key)
for new_key, (key,
default_value) in top_level_mapping_with_default.items():
config[new_key] = config.pop(key, default_value)
return config
def _remap_mistral_quantization_args(config: dict) -> dict:
quantization = config.get("quantization", {})
if quantization.get("qformat_weight") == "fp8_e4m3":
# This maps to the FP8 static per-tensor quantization scheme
quantization_config = {
"quant_method": "fp8",
"activation_scheme": "static"
}
elif quantization.get("quant_method") == "compressed-tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
raise ValueError(
f"Found unknown quantization='{quantization}' in config")
config["quantization_config"] = quantization_config
return config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment