Unverified commit ffbc2e5b, authored by Julien Denize and committed by GitHub
Browse files

Patch Mistral config (#37104)


Signed-off-by: juliendenize <julien.denize@mistral.ai>
parent f9e6db30
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os import os
from collections.abc import Callable from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import asdict from dataclasses import asdict
from functools import cache, partial from functools import cache, partial
from importlib.metadata import version from importlib.metadata import version
...@@ -10,8 +11,10 @@ from pathlib import Path ...@@ -10,8 +11,10 @@ from pathlib import Path
from typing import Any, Literal, TypeAlias from typing import Any, Literal, TypeAlias
import huggingface_hub import huggingface_hub
from huggingface_hub import get_safetensors_metadata import torch
from huggingface_hub import constants, get_safetensors_metadata
from packaging.version import Version from packaging.version import Version
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import GenerationConfig, PretrainedConfig from transformers import GenerationConfig, PretrainedConfig
from transformers.models.auto.image_processing_auto import get_image_processor_config from transformers.models.auto.image_processing_auto import get_image_processor_config
from transformers.models.auto.modeling_auto import ( from transformers.models.auto.modeling_auto import (
...@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import ( ...@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import (
parse_safetensors_file_metadata, parse_safetensors_file_metadata,
without_trust_remote_code, without_trust_remote_code,
) )
from vllm.utils.torch_utils import common_broadcastable_dtype
from .config_parser_base import ConfigParserBase from .config_parser_base import ConfigParserBase
from .gguf_utils import ( from .gguf_utils import (
...@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool: ...@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES) return set(rope_parameters.keys()).issubset(ALLOWED_ATTENTION_LAYER_TYPES)
@contextmanager
def _mistral_patch_hf_hub_constants() -> Iterator[None]:
    """Temporarily use Mistral's safetensors file names in the hub constants.

    While the context is active, ``constants.SAFETENSORS_SINGLE_FILE`` and
    ``constants.SAFETENSORS_INDEX_FILE`` hold the ``consolidated*`` file names.
    The values that were present on entry are written back on exit, including
    when the wrapped body raises.
    """
    # Snapshot the current values so they can be restored afterwards.
    saved_names = (
        constants.SAFETENSORS_SINGLE_FILE,
        constants.SAFETENSORS_INDEX_FILE,
    )
    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        (
            constants.SAFETENSORS_SINGLE_FILE,
            constants.SAFETENSORS_INDEX_FILE,
        ) = saved_names
class HFConfigParser(ConfigParserBase): class HFConfigParser(ConfigParserBase):
def parse( def parse(
self, self,
...@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase): ...@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase):
except OSError: # Not found except OSError: # Not found
hf_config_dict = {} hf_config_dict = {}
if config_dict.get("dtype") is None:
with _mistral_patch_hf_hub_constants():
model_str = model if isinstance(model, str) else model.as_posix()
param_mt = get_safetensors_params_metadata(model_str, revision=revision)
if param_mt:
param_dtypes: set[torch.dtype] = {
_SAFETENSORS_TO_TORCH_DTYPE[dtype]
for info in param_mt.values()
if (dtype := info.get("dtype", None))
and dtype in _SAFETENSORS_TO_TORCH_DTYPE
}
if param_dtypes:
config_dict["dtype"] = common_broadcastable_dtype(param_dtypes)
logger.info_once(
"Inferred from consolidated*.safetensors files "
f"{config_dict['dtype']} dtype."
)
config = adapt_config_dict(config_dict, defaults=hf_config_dict) config = adapt_config_dict(config_dict, defaults=hf_config_dict)
return config_dict, config return config_dict, config
......
...@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict: ...@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:
def _remap_mistral_yarn_args(config: dict) -> dict: def _remap_mistral_yarn_args(config: dict) -> dict:
yarn_config_map = { yarn_config_map = {
"factor": "factor", "factor": ("factor", float),
"original_max_position_embeddings": "original_max_position_embeddings", "original_max_position_embeddings": ("original_max_position_embeddings", int),
"beta": "beta_fast", "beta": ("beta_fast", float),
"alpha": "beta_slow", "alpha": ("beta_slow", float),
"apply_scale": "apply_yarn_scaling", "apply_scale": ("apply_yarn_scaling", bool),
} }
yarn_config = config.get("yarn") or {} yarn_config = config.get("yarn") or {}
config["rope_parameters"] = { config["rope_parameters"] = {
"rope_type": "yarn", "rope_type": "yarn",
...@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict: ...@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
if rope_theta := config.pop("rope_theta", None): if rope_theta := config.pop("rope_theta", None):
config["rope_parameters"]["rope_theta"] = rope_theta config["rope_parameters"]["rope_theta"] = rope_theta
for old_name, new_name in yarn_config_map.items(): for old_name, (new_name, cast) in yarn_config_map.items():
if old_name in yarn_config: if old_name in yarn_config:
config["rope_parameters"][new_name] = yarn_config.pop(old_name) # Cast to remove Transformers > v5 type warnings
config["rope_parameters"][new_name] = cast(yarn_config.pop(old_name))
assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}" assert len(yarn_config) == 0, f"Unparsed yarn config: {yarn_config}"
...@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict: ...@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"tie_word_embeddings": ("tied_embeddings", False), "tie_word_embeddings": ("tied_embeddings", False),
"max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)), "max_seq_len": ("max_seq_len", config.get("max_position_embeddings", 128_000)),
"max_position_embeddings": ("max_position_embeddings", 128_000), "max_position_embeddings": ("max_position_embeddings", 128_000),
"dtype": ("dtype", config.get("dtype")),
} }
for key, new_key in config_mapping.items(): for key, new_key in config_mapping.items():
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterator
from contextlib import contextmanager
from typing import final from typing import final
import torch import torch
from huggingface_hub import constants
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype ...@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
logger = init_logger(__name__) logger = init_logger(__name__)
@contextmanager
def _maybe_patch_hf_hub_constants(config_format: ConfigFormat) -> Iterator[None]:
    """Patch the hub safetensors file-name constants for the "mistral" format.

    For any other ``config_format`` the context is a no-op. For ``"mistral"``,
    ``constants.SAFETENSORS_SINGLE_FILE`` and ``constants.SAFETENSORS_INDEX_FILE``
    are set to the ``consolidated*`` file names for the duration of the context
    and reset to their entry values on exit, including when the body raises.
    """
    if config_format != "mistral":
        # Nothing to patch for non-Mistral config formats.
        yield
        return

    previous_single = constants.SAFETENSORS_SINGLE_FILE
    previous_index = constants.SAFETENSORS_INDEX_FILE
    constants.SAFETENSORS_SINGLE_FILE = "consolidated.safetensors"
    constants.SAFETENSORS_INDEX_FILE = "consolidated.safetensors.index.json"
    try:
        yield
    finally:
        constants.SAFETENSORS_SINGLE_FILE = previous_single
        constants.SAFETENSORS_INDEX_FILE = previous_index
class ModelArchConfigConvertorBase: class ModelArchConfigConvertorBase:
def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig): def __init__(self, hf_config: PretrainedConfig, hf_text_config: PretrainedConfig):
self.hf_config = hf_config self.hf_config = hf_config
...@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase: ...@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase:
# Try to read the dtype of the weights if they are in safetensors format # Try to read the dtype of the weights if they are in safetensors format
if config_dtype is None: if config_dtype is None:
with _maybe_patch_hf_hub_constants(config_format): param_mt = get_safetensors_params_metadata(model_id, revision=revision)
param_mt = get_safetensors_params_metadata(model_id, revision=revision)
if param_mt: if param_mt:
param_dtypes: set[torch.dtype] = { param_dtypes: set[torch.dtype] = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment