Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ffbc2e5b
Unverified
Commit
ffbc2e5b
authored
Mar 16, 2026
by
Julien Denize
Committed by
GitHub
Mar 16, 2026
Browse files
Patch Mistral config (#37104)
Signed-off-by:
juliendenize
<
julien.denize@mistral.ai
>
parent
f9e6db30
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
30 deletions
+49
-30
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+38
-2
vllm/transformers_utils/configs/mistral.py
vllm/transformers_utils/configs/mistral.py
+10
-7
vllm/transformers_utils/model_arch_config_convertor.py
vllm/transformers_utils/model_arch_config_convertor.py
+1
-21
No files found.
vllm/transformers_utils/config.py
View file @
ffbc2e5b
...
@@ -2,7 +2,8 @@
...
@@ -2,7 +2,8 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
import
os
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Iterator
from
contextlib
import
contextmanager
from
dataclasses
import
asdict
from
dataclasses
import
asdict
from
functools
import
cache
,
partial
from
functools
import
cache
,
partial
from
importlib.metadata
import
version
from
importlib.metadata
import
version
...
@@ -10,8 +11,10 @@ from pathlib import Path
...
@@ -10,8 +11,10 @@ from pathlib import Path
from
typing
import
Any
,
Literal
,
TypeAlias
from
typing
import
Any
,
Literal
,
TypeAlias
import
huggingface_hub
import
huggingface_hub
from
huggingface_hub
import
get_safetensors_metadata
import
torch
from
huggingface_hub
import
constants
,
get_safetensors_metadata
from
packaging.version
import
Version
from
packaging.version
import
Version
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers
import
GenerationConfig
,
PretrainedConfig
from
transformers.models.auto.image_processing_auto
import
get_image_processor_config
from
transformers.models.auto.image_processing_auto
import
get_image_processor_config
from
transformers.models.auto.modeling_auto
import
(
from
transformers.models.auto.modeling_auto
import
(
...
@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import (
...
@@ -28,6 +31,7 @@ from vllm.transformers_utils.utils import (
parse_safetensors_file_metadata
,
parse_safetensors_file_metadata
,
without_trust_remote_code
,
without_trust_remote_code
,
)
)
from
vllm.utils.torch_utils
import
common_broadcastable_dtype
from
.config_parser_base
import
ConfigParserBase
from
.config_parser_base
import
ConfigParserBase
from
.gguf_utils
import
(
from
.gguf_utils
import
(
...
@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
...
@@ -135,6 +139,19 @@ def is_rope_parameters_nested(rope_parameters: dict[str, Any]) -> bool:
return
set
(
rope_parameters
.
keys
()).
issubset
(
ALLOWED_ATTENTION_LAYER_TYPES
)
return
set
(
rope_parameters
.
keys
()).
issubset
(
ALLOWED_ATTENTION_LAYER_TYPES
)
@
contextmanager
def
_mistral_patch_hf_hub_constants
()
->
Iterator
[
None
]:
hf_safetensors_single_file
=
constants
.
SAFETENSORS_SINGLE_FILE
hf_safetensors_index_file
=
constants
.
SAFETENSORS_INDEX_FILE
constants
.
SAFETENSORS_SINGLE_FILE
=
"consolidated.safetensors"
constants
.
SAFETENSORS_INDEX_FILE
=
"consolidated.safetensors.index.json"
try
:
yield
finally
:
constants
.
SAFETENSORS_SINGLE_FILE
=
hf_safetensors_single_file
constants
.
SAFETENSORS_INDEX_FILE
=
hf_safetensors_index_file
class
HFConfigParser
(
ConfigParserBase
):
class
HFConfigParser
(
ConfigParserBase
):
def
parse
(
def
parse
(
self
,
self
,
...
@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase):
...
@@ -245,6 +262,25 @@ class MistralConfigParser(ConfigParserBase):
except
OSError
:
# Not found
except
OSError
:
# Not found
hf_config_dict
=
{}
hf_config_dict
=
{}
if
config_dict
.
get
(
"dtype"
)
is
None
:
with
_mistral_patch_hf_hub_constants
():
model_str
=
model
if
isinstance
(
model
,
str
)
else
model
.
as_posix
()
param_mt
=
get_safetensors_params_metadata
(
model_str
,
revision
=
revision
)
if
param_mt
:
param_dtypes
:
set
[
torch
.
dtype
]
=
{
_SAFETENSORS_TO_TORCH_DTYPE
[
dtype
]
for
info
in
param_mt
.
values
()
if
(
dtype
:
=
info
.
get
(
"dtype"
,
None
))
and
dtype
in
_SAFETENSORS_TO_TORCH_DTYPE
}
if
param_dtypes
:
config_dict
[
"dtype"
]
=
common_broadcastable_dtype
(
param_dtypes
)
logger
.
info_once
(
"Inferred from consolidated*.safetensors files "
f
"
{
config_dict
[
'dtype'
]
}
dtype."
)
config
=
adapt_config_dict
(
config_dict
,
defaults
=
hf_config_dict
)
config
=
adapt_config_dict
(
config_dict
,
defaults
=
hf_config_dict
)
return
config_dict
,
config
return
config_dict
,
config
...
...
vllm/transformers_utils/configs/mistral.py
View file @
ffbc2e5b
...
@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:
...
@@ -113,12 +113,13 @@ def _remap_mistral_vision_args(config: dict) -> dict:
def
_remap_mistral_yarn_args
(
config
:
dict
)
->
dict
:
def
_remap_mistral_yarn_args
(
config
:
dict
)
->
dict
:
yarn_config_map
=
{
yarn_config_map
=
{
"factor"
:
"factor"
,
"factor"
:
(
"factor"
,
float
),
"original_max_position_embeddings"
:
"original_max_position_embeddings"
,
"original_max_position_embeddings"
:
(
"original_max_position_embeddings"
,
int
),
"beta"
:
"beta_fast"
,
"beta"
:
(
"beta_fast"
,
float
),
"alpha"
:
"beta_slow"
,
"alpha"
:
(
"beta_slow"
,
float
),
"apply_scale"
:
"apply_yarn_scaling"
,
"apply_scale"
:
(
"apply_yarn_scaling"
,
bool
),
}
}
yarn_config
=
config
.
get
(
"yarn"
)
or
{}
yarn_config
=
config
.
get
(
"yarn"
)
or
{}
config
[
"rope_parameters"
]
=
{
config
[
"rope_parameters"
]
=
{
"rope_type"
:
"yarn"
,
"rope_type"
:
"yarn"
,
...
@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
...
@@ -128,9 +129,10 @@ def _remap_mistral_yarn_args(config: dict) -> dict:
if
rope_theta
:
=
config
.
pop
(
"rope_theta"
,
None
):
if
rope_theta
:
=
config
.
pop
(
"rope_theta"
,
None
):
config
[
"rope_parameters"
][
"rope_theta"
]
=
rope_theta
config
[
"rope_parameters"
][
"rope_theta"
]
=
rope_theta
for
old_name
,
new_name
in
yarn_config_map
.
items
():
for
old_name
,
(
new_name
,
cast
)
in
yarn_config_map
.
items
():
if
old_name
in
yarn_config
:
if
old_name
in
yarn_config
:
config
[
"rope_parameters"
][
new_name
]
=
yarn_config
.
pop
(
old_name
)
# Cast to remove Transformers > v5 type warnings
config
[
"rope_parameters"
][
new_name
]
=
cast
(
yarn_config
.
pop
(
old_name
))
assert
len
(
yarn_config
)
==
0
,
f
"Unparsed yarn config:
{
yarn_config
}
"
assert
len
(
yarn_config
)
==
0
,
f
"Unparsed yarn config:
{
yarn_config
}
"
...
@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
...
@@ -154,6 +156,7 @@ def _remap_general_mistral_args(config: dict) -> dict:
"tie_word_embeddings"
:
(
"tied_embeddings"
,
False
),
"tie_word_embeddings"
:
(
"tied_embeddings"
,
False
),
"max_seq_len"
:
(
"max_seq_len"
,
config
.
get
(
"max_position_embeddings"
,
128_000
)),
"max_seq_len"
:
(
"max_seq_len"
,
config
.
get
(
"max_position_embeddings"
,
128_000
)),
"max_position_embeddings"
:
(
"max_position_embeddings"
,
128_000
),
"max_position_embeddings"
:
(
"max_position_embeddings"
,
128_000
),
"dtype"
:
(
"dtype"
,
config
.
get
(
"dtype"
)),
}
}
for
key
,
new_key
in
config_mapping
.
items
():
for
key
,
new_key
in
config_mapping
.
items
():
...
...
vllm/transformers_utils/model_arch_config_convertor.py
View file @
ffbc2e5b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterator
from
contextlib
import
contextmanager
from
typing
import
final
from
typing
import
final
import
torch
import
torch
from
huggingface_hub
import
constants
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
safetensors.torch
import
_TYPES
as
_SAFETENSORS_TO_TORCH_DTYPE
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
...
@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
...
@@ -25,22 +22,6 @@ from vllm.utils.torch_utils import common_broadcastable_dtype
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
@
contextmanager
def
_maybe_patch_hf_hub_constants
(
config_format
:
ConfigFormat
)
->
Iterator
[
None
]:
if
config_format
==
"mistral"
:
hf_safetensors_single_file
=
constants
.
SAFETENSORS_SINGLE_FILE
hf_safetensors_index_file
=
constants
.
SAFETENSORS_INDEX_FILE
constants
.
SAFETENSORS_SINGLE_FILE
=
"consolidated.safetensors"
constants
.
SAFETENSORS_INDEX_FILE
=
"consolidated.safetensors.index.json"
try
:
yield
finally
:
constants
.
SAFETENSORS_SINGLE_FILE
=
hf_safetensors_single_file
constants
.
SAFETENSORS_INDEX_FILE
=
hf_safetensors_index_file
else
:
yield
class
ModelArchConfigConvertorBase
:
class
ModelArchConfigConvertorBase
:
def
__init__
(
self
,
hf_config
:
PretrainedConfig
,
hf_text_config
:
PretrainedConfig
):
def
__init__
(
self
,
hf_config
:
PretrainedConfig
,
hf_text_config
:
PretrainedConfig
):
self
.
hf_config
=
hf_config
self
.
hf_config
=
hf_config
...
@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase:
...
@@ -164,8 +145,7 @@ class ModelArchConfigConvertorBase:
# Try to read the dtype of the weights if they are in safetensors format
# Try to read the dtype of the weights if they are in safetensors format
if
config_dtype
is
None
:
if
config_dtype
is
None
:
with
_maybe_patch_hf_hub_constants
(
config_format
):
param_mt
=
get_safetensors_params_metadata
(
model_id
,
revision
=
revision
)
param_mt
=
get_safetensors_params_metadata
(
model_id
,
revision
=
revision
)
if
param_mt
:
if
param_mt
:
param_dtypes
:
set
[
torch
.
dtype
]
=
{
param_dtypes
:
set
[
torch
.
dtype
]
=
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment