Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
14601f5f
Unverified
Commit
14601f5f
authored
Jul 08, 2025
by
Patrick von Platen
Committed by
GitHub
Jul 07, 2025
Browse files
[Config] Refactor mistral configs (#20570)
Signed-off-by:
Patrick von Platen
<
patrick.v.platen@gmail.com
>
parent
042d131f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
167 additions
and
113 deletions
+167
-113
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+3
-0
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+44
-113
vllm/transformers_utils/configs/mistral.py
vllm/transformers_utils/configs/mistral.py
+120
-0
No files found.
vllm/model_executor/models/llama.py
View file @
14601f5f
...
@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
...
@@ -491,6 +491,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
"qscale_act"
:
"input_scale"
,
"qscale_act"
:
"input_scale"
,
"qscale_weight"
:
"weight_scale"
,
"qscale_weight"
:
"weight_scale"
,
"kv_fake_quantizer.qscale_act"
:
"kv_scale"
,
"kv_fake_quantizer.qscale_act"
:
"kv_scale"
,
"q_fake_quantizer.qscale_act"
:
"attn.q_scale"
,
"k_fake_quantizer.qscale_act"
:
"k_scale"
,
"v_fake_quantizer.qscale_act"
:
"v_scale"
,
"wq"
:
"q_proj"
,
"wq"
:
"q_proj"
,
"wk"
:
"k_proj"
,
"wk"
:
"k_proj"
,
"wv"
:
"v_proj"
,
"wv"
:
"v_proj"
,
...
...
vllm/transformers_utils/config.py
View file @
14601f5f
...
@@ -7,7 +7,7 @@ import os
...
@@ -7,7 +7,7 @@ import os
import
time
import
time
from
functools
import
cache
,
partial
from
functools
import
cache
,
partial
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Literal
,
Optional
,
TypeVar
,
Union
from
typing
import
Any
,
Callable
,
Optional
,
TypeVar
,
Union
import
huggingface_hub
import
huggingface_hub
from
huggingface_hub
import
get_safetensors_metadata
,
hf_hub_download
from
huggingface_hub
import
get_safetensors_metadata
,
hf_hub_download
...
@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
...
@@ -42,6 +42,7 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
SkyworkR1VChatConfig
,
SolarConfig
,
SkyworkR1VChatConfig
,
SolarConfig
,
Telechat2Config
,
UltravoxConfig
)
Telechat2Config
,
UltravoxConfig
)
# yapf: enable
# yapf: enable
from
vllm.transformers_utils.configs.mistral
import
adapt_config_dict
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.transformers_utils.utils
import
check_gguf_file
from
vllm.utils
import
resolve_obj_by_qualname
from
vllm.utils
import
resolve_obj_by_qualname
...
@@ -394,7 +395,16 @@ def get_config(
...
@@ -394,7 +395,16 @@ def get_config(
config
=
_maybe_remap_hf_config_attrs
(
config
)
config
=
_maybe_remap_hf_config_attrs
(
config
)
elif
config_format
==
ConfigFormat
.
MISTRAL
:
elif
config_format
==
ConfigFormat
.
MISTRAL
:
config
=
load_params_config
(
model
,
revision
,
**
kwargs
)
# This function loads a params.json config which
# should be used when loading models in mistral format
config_dict
=
_download_mistral_config_file
(
model
,
revision
)
if
(
max_position_embeddings
:
=
config_dict
.
get
(
"max_position_embeddings"
))
is
None
:
max_position_embeddings
=
_maybe_retrieve_max_pos_from_hf
(
model
,
revision
,
**
kwargs
)
config_dict
[
"max_position_embeddings"
]
=
max_position_embeddings
config
=
adapt_config_dict
(
config_dict
)
else
:
else
:
supported_formats
=
[
supported_formats
=
[
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
fmt
.
value
for
fmt
in
ConfigFormat
if
fmt
!=
ConfigFormat
.
AUTO
...
@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
...
@@ -693,117 +703,6 @@ def maybe_register_config_serialize_by_value() -> None:
exc_info
=
e
)
exc_info
=
e
)
def load_params_config(model: Union[str, Path], revision: Optional[str],
                       **kwargs) -> PretrainedConfig:
    """Load a mistral-format ``params.json`` and return it as an HF config.

    Args:
        model: Model identifier or path handed to ``get_hf_file_to_dict``.
        revision: Revision handed to ``get_hf_file_to_dict``.
        **kwargs: Extra entries merged into the top-level config dict.
            ``trust_remote_code`` is additionally forwarded to the HF
            config lookup used for ``max_position_embeddings``.

    Returns:
        A ``PretrainedConfig``. When the params contain a
        ``vision_encoder`` entry the result nests ``text_config`` and
        ``vision_config`` sub-configs.

    Raises:
        ValueError: If ``params.json`` cannot be loaded, or if its
            ``quantization`` entry is not a recognised scheme.
    """
    params_file = "params.json"

    config_dict = get_hf_file_to_dict(params_file, model, revision)
    if config_dict is None:
        raise ValueError(
            f"Failed to load mistral '{params_file}' config for model "
            f"{model}. Please check if the model is a mistral-format model "
            f"and if the config file exists.")
    assert isinstance(config_dict, dict)

    # Mistral key -> HF key; applied at every nesting level at the end.
    mistral_to_hf = {
        "dim": "hidden_size",
        "norm_eps": "rms_norm_eps",
        "n_kv_heads": "num_key_value_heads",
        "n_layers": "num_hidden_layers",
        "n_heads": "num_attention_heads",
        "hidden_dim": "intermediate_size",
    }

    def _rename_keys(node: Any):
        # Non-dict values are returned untouched; dicts are rebuilt with
        # renamed keys, recursing into their values.
        if not isinstance(node, dict):
            return node
        return {
            mistral_to_hf.get(name, name): _rename_keys(child)
            for name, child in node.items()
        }

    config_dict.setdefault("model_type", "transformer")
    config_dict["hidden_act"] = config_dict.get("activation", "silu")
    config_dict["tie_word_embeddings"] = config_dict.get(
        "tie_embeddings", False)

    if config_dict.get("max_position_embeddings") is None:
        # Prefer the value from the HF config; keep 128k when that lookup
        # fails or yields a falsy value.
        max_pos = 128_000
        try:
            allow_remote_code = kwargs.get("trust_remote_code", False)
            hf_config = get_config(model=model,
                                   trust_remote_code=allow_remote_code,
                                   revision=revision,
                                   config_format=ConfigFormat.HF)
            hf_max_pos = hf_config.get_text_config().max_position_embeddings
            if hf_max_pos:
                max_pos = hf_max_pos
        except Exception as err:
            logger.warning(
                "The params.json file is missing 'max_position_embeddings'"
                " and could not get a value from the HF config."
                " Defaulting to 128000",
                exc_info=err)
        config_dict["max_position_embeddings"] = max_pos

    quantization = config_dict.get("quantization")
    if quantization is not None:
        if quantization.get("qformat_weight") == "fp8_e4m3":
            # This maps to the FP8 static per-tensor quantization scheme
            quant_cfg = {"quant_method": "fp8", "activation_scheme": "static"}
        elif quantization.get("quant_method") == "compressed-tensors":
            # Pass through the quantization config to compressed-tensors
            quant_cfg = quantization
        else:
            raise ValueError(
                f"Found unknown quantization='{quantization}' in config")
        config_dict["quantization_config"] = quant_cfg

    is_multimodal = config_dict.get("vision_encoder") is not None

    config_dict["architectures"] = (["MixtralForCausalLM"]
                                    if config_dict.get("moe") is not None else
                                    ["MistralForCausalLM"])

    if is_multimodal:
        vision_dict = config_dict.pop("vision_encoder")
        outer_quant = config_dict.get("quantization_config", {})

        # The text parameters become a nested sub-config of the wrapper.
        config_dict = {
            "text_config": config_dict,
            "vision_config": vision_dict,
            "architectures": ["PixtralForConditionalGeneration"],
            "model_type": "pixtral",
        }
        if outer_quant:
            config_dict["quantization_config"] = outer_quant

    config_dict.update(kwargs)
    config_dict = _rename_keys(config_dict)

    # transform to HF config format
    if is_multimodal:
        for section in ("text_config", "vision_config"):
            config_dict[section] = PretrainedConfig(**config_dict[section])

    return PretrainedConfig(**config_dict)
def
get_hf_image_processor_config
(
def
get_hf_image_processor_config
(
model
:
Union
[
str
,
Path
],
model
:
Union
[
str
,
Path
],
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
hf_token
:
Optional
[
Union
[
bool
,
str
]]
=
None
,
...
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
...
@@ -920,3 +819,35 @@ def try_get_tokenizer_config(
)
)
except
Exception
:
except
Exception
:
return
None
return
None
def _download_mistral_config_file(model, revision) -> dict:
    """Fetch the mistral-format ``params.json`` for ``model`` as a dict.

    Args:
        model: Model identifier or path handed to ``get_hf_file_to_dict``.
        revision: Revision handed to ``get_hf_file_to_dict``.

    Returns:
        The parsed ``params.json`` contents.

    Raises:
        ValueError: If ``get_hf_file_to_dict`` returns ``None``.
    """
    params_file = "params.json"
    loaded = get_hf_file_to_dict(params_file, model, revision)
    if loaded is not None:
        assert isinstance(loaded, dict)
        return loaded
    raise ValueError(
        f"Failed to load mistral '{params_file}' config for model "
        f"{model}. Please check if the model is a mistral-format model "
        f"and if the config file exists.")
def _maybe_retrieve_max_pos_from_hf(model, revision, **kwargs) -> int:
    """Look up ``max_position_embeddings`` in the model's HF-format config.

    Args:
        model: Model identifier or path forwarded to ``get_config``.
        revision: Revision forwarded to ``get_config``.
        **kwargs: Only ``trust_remote_code`` is read (default ``False``).

    Returns:
        The text config's ``max_position_embeddings`` when it is truthy;
        otherwise 128000, including when the lookup raises.
    """
    fallback = 128_000
    try:
        allow_remote_code = kwargs.get("trust_remote_code", False)
        hf_config = get_config(model=model,
                               trust_remote_code=allow_remote_code,
                               revision=revision,
                               config_format=ConfigFormat.HF)
        hf_max_pos = hf_config.get_text_config().max_position_embeddings
    except Exception as err:
        # Best effort: any failure is logged and the fallback is used.
        logger.warning(
            "The params.json file is missing 'max_position_embeddings'"
            " and could not get a value from the HF config."
            " Defaulting to 128000",
            exc_info=err)
        return fallback
    return hf_max_pos if hf_max_pos else fallback
vllm/transformers_utils/configs/mistral.py
0 → 100644
View file @
14601f5f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
typing
import
Any
from
transformers
import
PretrainedConfig
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
def adapt_config_dict(config_dict: dict[str, Any],
                      **kwargs) -> PretrainedConfig:
    """Convert a mistral-format parameter dict into a ``PretrainedConfig``.

    Args:
        config_dict: Mistral parameters. The dict is updated in place with
            ``kwargs`` before being handed to the remapping helpers.
        **kwargs: Extra entries merged into ``config_dict`` first.

    Returns:
        The config built by ``PretrainedConfig.from_dict`` from the
        remapped dict.
    """
    config_dict.update(kwargs)
    config_dict = _remap_general_mistral_args(config_dict)

    if config_dict.get("quantization"):
        config_dict = _remap_mistral_quantization_args(config_dict)

    if config_dict.get("moe"):
        config_dict["architectures"] = ["MixtralForCausalLM"]
    else:
        config_dict["architectures"] = ["MistralForCausalLM"]

    if config_dict.get("yarn"):
        config_dict = _remap_mistral_yarn_args(config_dict)

    # Vision parameters may be given either under
    # "multimodal" -> "vision_encoder_args" or directly as "vision_encoder".
    has_vision = ((config_dict.get("multimodal")
                   or {}).get("vision_encoder_args")
                  or config_dict.get("vision_encoder"))
    if has_vision:
        config_dict = _remap_mistral_vision_args(config_dict)

    config = PretrainedConfig.from_dict(config_dict)
    # The placeholder is required: passing an argument without one makes
    # the logging module report a formatting error instead of the message.
    logger.debug("Initialized config %s", config)

    return config
def _remap_mistral_vision_args(config: dict) -> dict:
    """Wrap text and vision parameters into a pixtral-style config dict.

    The vision parameters are popped from ``config["multimodal"]`` when
    that entry is truthy, otherwise from ``config["vision_encoder"]``.
    What remains of ``config`` becomes the ``text_config``.

    Args:
        config: Parameter dict; the vision entry is removed from it.

    Returns:
        A new dict with ``model_type``, ``architectures``, ``text_config``
        and ``vision_config``, plus ``quantization_config`` when ``config``
        carried a truthy one.
    """
    source_key = "multimodal" if config.get("multimodal") else "vision_encoder"
    vision_dict = config.pop(source_key)

    quant_config = config.get("quantization_config")

    remapped = {
        "model_type": "pixtral",
        "architectures": ["PixtralForConditionalGeneration"],
        "text_config": PretrainedConfig.from_dict(config),
        "vision_config": PretrainedConfig.from_dict(vision_dict),
    }
    if quant_config:
        remapped["quantization_config"] = quant_config
    return remapped
def
_remap_mistral_yarn_args
(
config
:
dict
)
->
dict
:
# Direct remaps: yarn.X -> rope_scaling.Y
# Source keys are from mistral.model.args.YarnArgs
_map
=
{
"beta"
:
"beta_fast"
,
"alpha"
:
"beta_slow"
,
}
yarn_config
=
config
.
get
(
"yarn"
)
or
{}
renamed_yarn_config
=
{
_map
.
get
(
k
,
k
):
v
for
k
,
v
in
yarn_config
.
items
()}
config
[
"rope_scaling"
]
=
{
"rope_type"
:
"yarn"
,
"mscale_all_dim"
:
1
,
# We hardcoded this to 1
**
renamed_yarn_config
}
return
config
def
_remap_general_mistral_args
(
config
:
dict
)
->
dict
:
# Mistral key -> HF key
config_mapping
=
{
"dim"
:
"hidden_size"
,
"norm_eps"
:
"rms_norm_eps"
,
"n_kv_heads"
:
"num_key_value_heads"
,
"n_layers"
:
"num_hidden_layers"
,
"n_heads"
:
"num_attention_heads"
,
"hidden_dim"
:
"intermediate_size"
,
}
# HF key -> (Mistral key, default value)
top_level_mapping_with_default
=
{
"model_type"
:
(
"model_type"
,
"transformer"
),
"hidden_act"
:
(
"activation"
,
"silu"
),
"tie_word_embeddings"
:
(
"tied_embeddings"
,
False
),
"max_seq_len"
:
(
"max_seq_len"
,
128_000
),
"max_position_embeddings"
:
(
"max_position_embeddings"
,
128_000
),
}
for
key
,
new_key
in
config_mapping
.
items
():
if
key
in
config
:
config
[
new_key
]
=
config
.
pop
(
key
)
for
new_key
,
(
key
,
default_value
)
in
top_level_mapping_with_default
.
items
():
config
[
new_key
]
=
config
.
pop
(
key
,
default_value
)
return
config
def
_remap_mistral_quantization_args
(
config
:
dict
)
->
dict
:
quantization
=
config
.
get
(
"quantization"
,
{})
if
quantization
.
get
(
"qformat_weight"
)
==
"fp8_e4m3"
:
# This maps to the FP8 static per-tensor quantization scheme
quantization_config
=
{
"quant_method"
:
"fp8"
,
"activation_scheme"
:
"static"
}
elif
quantization
.
get
(
"quant_method"
)
==
"compressed-tensors"
:
# Pass through the quantization config to compressed-tensors
quantization_config
=
quantization
else
:
raise
ValueError
(
f
"Found unknown quantization='
{
quantization
}
' in config"
)
config
[
"quantization_config"
]
=
quantization_config
return
config
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment