Unverified commit d1cf8214, authored by Cyrus Leung and committed by GitHub
Browse files

[Bugfix] Use HF config fields as fallback when loading Mistral config (#29239)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 730bd353
......@@ -754,6 +754,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
- vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
......
......@@ -691,6 +691,7 @@ steps:
torch_nightly: true
source_file_dependencies:
- vllm/model_executor/models/
- vllm/transformers_utils/
- tests/models/test_initialization.py
commands:
# Only when vLLM model source is modified - test initialization of a large
......
......@@ -204,7 +204,19 @@ class MistralConfigParser(ConfigParserBase):
from vllm.transformers_utils.configs.mistral import adapt_config_dict
config = adapt_config_dict(config_dict)
# Get missing fields from HF config if available
try:
hf_config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
**kwargs,
)
except OSError: # Not found
hf_config_dict = {}
config = adapt_config_dict(config_dict, defaults=hf_config_dict)
# Mistral configs may define sliding_window as list[int]. Convert it
# to int and add the layer_types list[str] to make it HF compatible
......
......@@ -9,14 +9,18 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig:
config_dict.update(kwargs)
def adapt_config_dict(
config_dict: dict[str, Any],
defaults: dict[str, Any],
) -> PretrainedConfig:
config_dict = _remap_general_mistral_args(config_dict)
if bool(config_dict.get("quantization")):
config_dict = _remap_mistral_quantization_args(config_dict)
if bool(config_dict.get("moe")):
if config_dict.get("model_type") == "mamba":
config_dict["architectures"] = ["Mamba2ForCausalLM"]
elif bool(config_dict.get("moe")):
config_dict["architectures"] = ["MixtralForCausalLM"]
else:
config_dict["architectures"] = ["MistralForCausalLM"]
......@@ -52,6 +56,9 @@ def adapt_config_dict(config_dict: dict[str, Any], **kwargs) -> PretrainedConfig
if is_audio:
config_dict = _remap_mistral_audio_args(config_dict)
for k, v in defaults.items():
config_dict.setdefault(k, v)
config = PretrainedConfig.from_dict(config_dict)
logger.debug("Initialized config %s", config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment