Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bdcb42e4
Unverified
Commit
bdcb42e4
authored
Aug 05, 2025
by
Po-Han Huang (NVIDIA)
Committed by
GitHub
Aug 04, 2025
Browse files
[NVIDIA] Auto detect modelopt quant and fix DSR1-FP4 weight loading (#22073)
parent
c09efff9
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
67 additions
and
15 deletions
+67
-15
vllm/config.py
vllm/config.py
+15
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+38
-15
vllm/transformers_utils/config.py
vllm/transformers_utils/config.py
+14
-0
No files found.
vllm/config.py
View file @
bdcb42e4
...
@@ -1108,6 +1108,21 @@ class ModelConfig:
...
@@ -1108,6 +1108,21 @@ class ModelConfig:
if
quant_cfg
is
None
:
if
quant_cfg
is
None
:
# compressed-tensors uses a "compression_config" key
# compressed-tensors uses a "compression_config" key
quant_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
quant_cfg
=
getattr
(
self
.
hf_config
,
"compression_config"
,
None
)
else
:
# Set quant_method for ModelOpt models.
producer_name
=
quant_cfg
.
get
(
"producer"
,
{}).
get
(
"name"
)
if
producer_name
==
"modelopt"
:
quant_algo
=
quant_cfg
.
get
(
"quantization"
,
{}).
get
(
"quant_algo"
)
if
quant_algo
==
"FP8"
:
quant_cfg
[
"quant_method"
]
=
"modelopt"
elif
quant_algo
==
"NVFP4"
:
quant_cfg
[
"quant_method"
]
=
"modelopt_fp4"
elif
quant_algo
is
not
None
:
raise
ValueError
(
f
"Unknown ModelOpt quant algo:
{
quant_algo
}
"
)
return
quant_cfg
return
quant_cfg
def
_verify_quantization
(
self
)
->
None
:
def
_verify_quantization
(
self
)
->
None
:
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
bdcb42e4
...
@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
...
@@ -919,9 +919,13 @@ class FusedMoE(torch.nn.Module):
elif
shard_id
==
"w2"
:
elif
shard_id
==
"w2"
:
param_data
[
expert_id
]
=
loaded_weight
param_data
[
expert_id
]
=
loaded_weight
def
_load_w13_weight_scale
(
self
,
shard_dim
:
int
,
def
_load_
combined_
w13_weight_scale
(
self
,
shard_dim
:
int
,
loaded_weight
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
,
param
:
torch
.
Tensor
,
tp_rank
:
int
):
param
:
torch
.
Tensor
,
tp_rank
:
int
):
"""
Load w13 weight scales assuming that w1 weight scales and w3 weight
scales are stored in the same loaded_weight tensor.
"""
shard_size
=
param
.
shape
[
shard_dim
]
shard_size
=
param
.
shape
[
shard_dim
]
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
loaded_weight
=
loaded_weight
.
narrow
(
shard_dim
,
shard_size
*
tp_rank
,
shard_size
)
shard_size
)
...
@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
...
@@ -1168,24 +1172,43 @@ class FusedMoE(torch.nn.Module):
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
(
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
(
)
)
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
# Call _load_per_tensor_weight_scale() to load per-tensor (scalar)
per_tensor_conditions
=
(
# weights scales.
"weight_scale_2"
in
weight_name
if
uses_weight_scale_2
else
# Input scales are always per-tensor.
"weight_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
# Weight scales: FP4 uses "weight_scale_2" and FP8 uses
# "weight_scale" for per-tensor scales.
if
"w13_weight_scale"
in
weight_name
:
is_per_tensor
=
(
"weight_scale_2"
in
weight_name
self
.
_load_w13_weight_scale
(
shard_dim
=
shard_dim
,
if
uses_weight_scale_2
else
"weight_scale"
loaded_weight
=
loaded_weight
,
in
weight_name
)
or
"input_scale"
in
weight_name
param
=
param
,
if
is_per_tensor
:
tp_rank
=
self
.
tp_rank
)
elif
per_tensor_conditions
:
self
.
_load_per_tensor_weight_scale
(
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
param
=
param
,
param
=
param
,
loaded_weight
=
loaded_weight
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
,
expert_id
=
expert_id
,
)
)
elif
"weight"
in
weight_name
:
return
True
if
return_success
else
None
# If the weight is w13_weight_scale and w13_weight_scales are
# combined into single loaded_weight, call
# _load_combined_w13_weight_scale() to load it.
# This is checked by comparing the hidden_out dims of the
# loaded_weight and the param.
if
"w13_weight_scale"
in
weight_name
:
loaded_weight_hidden_out
=
loaded_weight
.
shape
[
-
2
]
param_hidden_out
=
param
.
data
.
shape
[
-
2
]
*
self
.
tp_size
if
loaded_weight_hidden_out
==
param_hidden_out
:
self
.
_load_combined_w13_weight_scale
(
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
param
=
param
,
tp_rank
=
self
.
tp_rank
,
)
return
True
if
return_success
else
None
# For other weights, call _load_model_weight_or_group_weight_scale()
# to load it.
if
"weight"
in
weight_name
:
self
.
_load_model_weight_or_group_weight_scale
(
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
shard_dim
=
shard_dim
,
shard_dim
=
shard_dim
,
...
...
vllm/transformers_utils/config.py
View file @
bdcb42e4
...
@@ -449,6 +449,20 @@ def get_config(
...
@@ -449,6 +449,20 @@ def get_config(
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
model_type
=
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
[
config
.
model_type
]
config
.
update
({
"architectures"
:
[
model_type
]})
config
.
update
({
"architectures"
:
[
model_type
]})
# ModelOpt 0.31.0 and after saves the quantization config in the model
# config file.
quantization_config
=
config_dict
.
get
(
"quantization_config"
,
None
)
# ModelOpt 0.29.0 and before saves the quantization config in a separate
# "hf_quant_config.json" in the same directory as the model config file.
if
quantization_config
is
None
\
and
file_or_path_exists
(
model
,
"hf_quant_config.json"
,
revision
):
quantization_config
=
get_hf_file_to_dict
(
"hf_quant_config.json"
,
model
,
revision
)
if
quantization_config
is
not
None
:
config
.
quantization_config
=
quantization_config
if
hf_overrides_kw
:
if
hf_overrides_kw
:
logger
.
debug
(
"Overriding HF config with %s"
,
hf_overrides_kw
)
logger
.
debug
(
"Overriding HF config with %s"
,
hf_overrides_kw
)
config
.
update
(
hf_overrides_kw
)
config
.
update
(
hf_overrides_kw
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment