Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4afe687a
Unverified
Commit
4afe687a
authored
Jul 11, 2025
by
Zhiyu
Committed by
GitHub
Jul 11, 2025
Browse files
Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)
Signed-off-by:
Zhiyu Cheng
<
zhiyuc@nvidia.com
>
parent
5de8d9f1
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
501 additions
and
35 deletions
+501
-35
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+31
-6
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+261
-5
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+10
-0
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+55
-4
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+144
-20
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
4afe687a
...
@@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
...
@@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
raise
NotImplementedError
raise
NotImplementedError
def
uses_weight_scale_2_pattern
(
self
)
->
bool
:
"""
Returns True if this quantization method uses 'weight_scale_2' pattern
for per-tensor weight scales (e.g., FP4 variants), False otherwise.
This method should be overridden by subclasses that use the
'weight_scale_2' pattern instead of the standard 'weight_scale' pattern.
"""
return
False
@
staticmethod
@
staticmethod
def
maybe_make_prepare_finalize
(
def
maybe_make_prepare_finalize
(
moe
:
FusedMoEConfig
)
->
Optional
[
FusedMoEPrepareAndFinalize
]:
moe
:
FusedMoEConfig
)
->
Optional
[
FusedMoEPrepareAndFinalize
]:
...
@@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
...
@@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if
"ModelOpt"
in
quant_method_name
:
if
"ModelOpt"
in
quant_method_name
:
if
(
'weight_scale_2'
in
weight_name
# Determine per-tensor weight scale patterns based on variant
or
'input_scale'
in
weight_name
):
# Use the dedicated method instead of brittle string matching
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
(
param
=
param
,
)
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
per_tensor_conditions
=
(
"weight_scale_2"
in
weight_name
if
uses_weight_scale_2
else
"weight_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
if
per_tensor_conditions
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
,
)
elif
"weight"
in
weight_name
:
elif
"weight"
in
weight_name
:
self
.
_load_model_weight_or_group_weight_scale
(
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
shard_id
=
shard_id
,
...
@@ -1558,3 +1579,7 @@ direct_register_custom_op(
...
@@ -1558,3 +1579,7 @@ direct_register_custom_op(
dispatch_key
=
current_platform
.
dispatch_key
,
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE
.
weight_loader
.
supports_moe_loading
=
True
# type: ignore[attr-defined]
vllm/model_executor/layers/quantization/modelopt.py
View file @
4afe687a
...
@@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
def
__init__
(
def
__init__
(
self
,
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
kv_cache_quant_method
:
Optional
[
str
]
=
None
,
exclude_modules
:
Optional
[
list
[
str
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
kv_cache_quant_method
=
kv_cache_quant_method
self
.
exclude_modules
=
exclude_modules
if
is_checkpoint_fp8_serialized
:
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected ModelOpt fp8 checkpoint. Please note that"
logger
.
warning
(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
)
" the format is experimental and could change."
)
...
@@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
quant_method
=
quant_config
[
"quant_algo"
]
kv_cache_quant_method
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"exclude_modules"
)
if
quant_method
not
in
QUANT_ALGOS
:
if
quant_method
not
in
QUANT_ALGOS
:
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
" quantizations in vLLM. Please check the "
" quantizations in vLLM. Please check the "
...
@@ -76,27 +85,51 @@ class ModelOptFp8Config(QuantizationConfig):
...
@@ -76,27 +85,51 @@ class ModelOptFp8Config(QuantizationConfig):
"quant configuration."
)
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
return
cls
(
is_checkpoint_fp8_serialized
)
return
cls
(
is_checkpoint_fp8_serialized
,
kv_cache_quant_method
,
exclude_modules
)
def
is_layer_excluded
(
self
,
prefix
:
str
)
->
bool
:
"""
Check if a layer should be excluded from quantization.
This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the
module name (without the language_model prefix) is in the exclude list.
"""
if
self
.
exclude_modules
is
None
:
return
False
# Check if any excluded module matches the prefix
for
module
in
self
.
exclude_modules
:
if
(
module
in
prefix
or
(
prefix
.
startswith
(
"language_model."
)
and
module
in
prefix
.
removeprefix
(
"language_model."
))):
return
True
return
False
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
def
get_quant_method
(
self
,
layer
:
torch
.
nn
.
Module
,
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
prefix
:
str
)
->
Optional
[
"QuantizeMethodBase"
]:
from
vllm.attention.layer
import
Attention
# Avoid circular import
from
vllm.attention.layer
import
Attention
# Avoid circular import
if
isinstance
(
layer
,
LinearBase
):
if
isinstance
(
layer
,
LinearBase
):
if
self
.
is_layer_excluded
(
prefix
):
return
UnquantizedLinearMethod
()
return
ModelOptFp8LinearMethod
(
self
)
return
ModelOptFp8LinearMethod
(
self
)
elif
isinstance
(
layer
,
Attention
):
elif
isinstance
(
layer
,
Attention
):
return
ModelOptFp8KVCacheMethod
(
self
)
return
ModelOptFp8KVCacheMethod
(
self
)
elif
isinstance
(
layer
,
FusedMoE
):
return
ModelOptFp8MoEMethod
(
self
)
return
None
return
None
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
class
ModelOptFp8LinearMethod
(
LinearMethodBase
):
"""Linear method for Model Optimizer static quantization.
"""Linear method for Model Optimizer static quantization.
Supports loading FP8 checkpoints with static weight scale and
Supports loading FP8 checkpoints with static weight scale and
activation scale. Future support might be added for dynamic
activation scale. Future support might be added for dynamic
scales.
scales.
Limitations:
Limitations:
1. Only support per-tensor quantization due to torch._scaled_mm support.
1. Only support per-tensor quantization due to torch._scaled_mm support.
2. Only support float8_e4m3fn datatype
2. Only support float8_e4m3fn datatype
Args: quant_config: The ModelOpt quantization config.
Args: quant_config: The ModelOpt quantization config.
"""
"""
...
@@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
bias
=
bias
)
bias
=
bias
)
class
ModelOptFp8MoEMethod
(
FusedMoEMethodBase
):
"""MoE method for ModelOpt FP8.
Supports loading FP8 checkpoints with static weight scale and
activation scale.
Args:
quant_config: The ModelOpt quantization config.
"""
def
__init__
(
self
,
quant_config
:
ModelOptFp8Config
):
self
.
quant_config
=
quant_config
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
cutlass_fp8_supported
)
self
.
cutlass_fp8_supported
=
cutlass_fp8_supported
()
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
):
# Use FP8 dtype if checkpoint is serialized
weight_dtype
=
(
torch
.
float8_e4m3fn
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
else
params_dtype
)
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
w13_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
weight_dtype
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
w2_weight
=
ModelWeightParameter
(
data
=
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
weight_dtype
),
input_dim
=
2
,
output_dim
=
1
,
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
if
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# WEIGHT SCALES - Per-tensor scaling for ModelOpts
# Allocate 2 scales for w1 and w3 respectively.
# They will be combined to a single scale after weight loading.
w13_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
(
(
num_experts
,
2
),
1.0
,
dtype
=
torch
.
float32
,
),
weight_loader
=
weight_loader
,
)
w2_weight_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Set weight loader attributes for scales
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
TENSOR
.
value
})
# INPUT SCALES - Per-tensor scaling for ModelOpt
w13_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
w2_input_scale
=
PerTensorScaleParameter
(
data
=
torch
.
full
((
num_experts
,
),
1.0
,
dtype
=
torch
.
float32
),
weight_loader
=
weight_loader
,
)
layer
.
register_parameter
(
"w13_input_scale"
,
w13_input_scale
)
layer
.
register_parameter
(
"w2_input_scale"
,
w2_input_scale
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
"""Process FP8 MoE weights after loading from serialized checkpoint.
Only supports pre-quantized checkpoints with FP8 weights and scales.
"""
layer
.
w13_weight
=
Parameter
(
layer
.
w13_weight
.
data
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
layer
.
w2_weight
.
data
,
requires_grad
=
False
)
from
vllm._custom_ops
import
scaled_fp8_quant
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
per_tensor_dequantize
)
# Handle scale parameters
if
hasattr
(
layer
,
"w13_weight_scale"
)
and
layer
.
w13_weight_scale
is
not
None
:
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max of the w1 and w3 scales
# then dequant and requant each expert.
if
layer
.
w13_weight_scale
.
dim
()
==
2
:
# Get the maximum scale across w1 and w3 for each expert
max_w13_scales
=
layer
.
w13_weight_scale
.
max
(
dim
=
1
).
values
# Requantize each expert's weights using the combined scale
# w13_weight (num_experts, 2 * intermediate_size, hidden_size)
# where the first intermediate_size rows are w1, the next are w3
intermediate_size
=
layer
.
w13_weight
.
shape
[
1
]
//
2
for
expert_id
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
start
=
0
for
shard_id
in
range
(
2
):
# w1 and w3
# Dequantize using the original scale for this shard
dq_weight
=
per_tensor_dequantize
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
intermediate_size
,
:],
layer
.
w13_weight_scale
[
expert_id
][
shard_id
],
)
# Requantize using the combined max scale
(
layer
.
w13_weight
[
expert_id
][
start
:
start
+
intermediate_size
,
:],
_
,
)
=
scaled_fp8_quant
(
dq_weight
,
max_w13_scales
[
expert_id
])
start
+=
intermediate_size
# Update the scale parameter to be per-expert
layer
.
w13_weight_scale
=
Parameter
(
max_w13_scales
,
requires_grad
=
False
)
else
:
layer
.
w13_weight_scale
=
Parameter
(
layer
.
w13_weight_scale
.
data
,
requires_grad
=
False
)
if
hasattr
(
layer
,
"w2_weight_scale"
)
and
layer
.
w2_weight_scale
is
not
None
:
layer
.
w2_weight_scale
=
Parameter
(
layer
.
w2_weight_scale
.
data
,
requires_grad
=
False
)
# Input scales must be equal for each expert in fp8 MoE layers.
if
hasattr
(
layer
,
"w13_input_scale"
)
and
layer
.
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
Parameter
(
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
if
hasattr
(
layer
,
"w2_input_scale"
)
and
layer
.
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
,
top_k
:
int
,
renormalize
:
bool
,
use_grouped_topk
:
bool
=
False
,
topk_group
:
Optional
[
int
]
=
None
,
num_expert_group
:
Optional
[
int
]
=
None
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for `ModelOptFp8MoEMethod` yet."
)
# Expert selection
topk_weights
,
topk_ids
=
FusedMoE
.
select_experts
(
hidden_states
=
x
,
router_logits
=
router_logits
,
use_grouped_topk
=
use_grouped_topk
,
top_k
=
top_k
,
renormalize
=
renormalize
,
topk_group
=
topk_group
,
num_expert_group
=
num_expert_group
,
custom_routing_function
=
custom_routing_function
,
scoring_func
=
scoring_func
,
e_score_correction_bias
=
e_score_correction_bias
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
(
fused_experts
)
return
fused_experts
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
False
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
)
class
ModelOptNvFp4Config
(
QuantizationConfig
):
class
ModelOptNvFp4Config
(
QuantizationConfig
):
"""Config class for ModelOpt FP4."""
"""Config class for ModelOpt FP4."""
...
@@ -274,7 +524,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
...
@@ -274,7 +524,7 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
class
ModelOptNvFp4LinearMethod
(
LinearMethodBase
):
class
ModelOptNvFp4LinearMethod
(
LinearMethodBase
):
"""Linear method for Model Optimizer NVFP4.
"""Linear method for Model Optimizer NVFP4.
Supports loading NVFP4 checkpoints with the following structure:
Supports loading NVFP4 checkpoints with the following structure:
input_scale: torch.float32, scalar ,
input_scale: torch.float32, scalar ,
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
weight: NVFP4(represented as byte) Shape: [1, X, y/2]
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale,
...
@@ -455,7 +705,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
...
@@ -455,7 +705,7 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
class
ModelOptNvFp4FusedMoE
(
FusedMoEMethodBase
):
"""
"""
MoE Method for FP4 Quantization.
MoE Method for FP4 Quantization.
Args:
Args:
quant_config: NVFP4 Quant Config
quant_config: NVFP4 Quant Config
"""
"""
...
@@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
...
@@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" quantization. Please use Blackwell and"
" above."
)
" above."
)
def
uses_weight_scale_2_pattern
(
self
)
->
bool
:
"""
FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales.
"""
return
True
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
4afe687a
...
@@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
...
@@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
modelopt_scale_names
=
[
modelopt_scale_names
=
[
".self_attn.k_proj.k_scale"
,
".self_attn.v_proj.v_scale"
".self_attn.k_proj.k_scale"
,
".self_attn.v_proj.v_scale"
]
]
# Also support qkv_proj scale parameters (from stacked parameter processing)
qkv_proj_scale_names
=
[
".self_attn.qkv_proj.k_scale"
,
".self_attn.qkv_proj.v_scale"
]
for
scale_name
in
possible_scale_names
:
for
scale_name
in
possible_scale_names
:
if
name
.
endswith
(
scale_name
):
if
name
.
endswith
(
scale_name
):
if
any
(
mo_scale_name
in
name
if
any
(
mo_scale_name
in
name
...
@@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
...
@@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name
=
name
.
replace
(
remapped_name
=
name
.
replace
(
f
".self_attn.
{
scale_name
[
1
]
}
_proj
{
scale_name
}
"
,
f
".self_attn.
{
scale_name
[
1
]
}
_proj
{
scale_name
}
"
,
f
".self_attn.attn
{
scale_name
}
"
)
f
".self_attn.attn
{
scale_name
}
"
)
elif
any
(
qkv_scale_name
in
name
for
qkv_scale_name
in
qkv_proj_scale_names
):
# Handle qkv_proj scale parameters
remapped_name
=
name
.
replace
(
f
".self_attn.qkv_proj
{
scale_name
}
"
,
f
".self_attn.attn
{
scale_name
}
"
)
else
:
else
:
remapped_name
=
name
.
replace
(
scale_name
,
f
".attn
{
scale_name
}
"
)
remapped_name
=
name
.
replace
(
scale_name
,
f
".attn
{
scale_name
}
"
)
if
remapped_name
not
in
params_dict
:
if
remapped_name
not
in
params_dict
:
...
...
vllm/model_executor/models/llama4.py
View file @
4afe687a
...
@@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
...
@@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
.llama
import
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.llama
import
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
fast_topk
,
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
fast_topk
,
...
@@ -432,12 +433,24 @@ class Llama4Model(LlamaModel):
...
@@ -432,12 +433,24 @@ class Llama4Model(LlamaModel):
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
"experts"
in
name
:
if
weight_name
not
in
name
or
"experts"
in
name
:
continue
continue
name
=
name
.
replace
(
weight_name
,
param_name
)
# This check is for ModelOpt ckpts with kv cache quant enabled
if
not
(
name
.
endswith
(
(
".k_scale"
,
".v_scale"
))
and
"self_attn"
in
name
):
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
if
name
.
endswith
(
"scale"
)
and
"expert"
not
in
name
:
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
(
param
,
loaded_weight
,
shard_id
)
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
loaded_params
.
add
(
name
)
break
break
else
:
else
:
...
@@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
...
@@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
if
not
moe_loaded
:
if
not
moe_loaded
:
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if
(
"experts."
in
name
and
(
"w13_input_scale"
in
name
or
"w13_weight_scale"
in
name
or
"w2_input_scale"
in
name
or
"w2_weight_scale"
in
name
)):
# These are flat expert scales that apply to all experts
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe
=
getattr
(
weight_loader
,
'supports_moe_loading'
,
False
)
if
supports_moe
:
# This is a MoE weight loader
if
"w13_"
in
name
:
shard_id
=
"w1"
elif
"w2_"
in
name
:
shard_id
=
"w2"
else
:
shard_id
=
"w1"
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
0
)
else
:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
continue
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/mllama4.py
View file @
4afe687a
...
@@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP
):
SupportsPP
):
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
}
@
classmethod
@
classmethod
...
@@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
qkv_weight
=
torch
.
cat
(
weight
,
dim
=
0
)
qkv_weight
=
torch
.
cat
(
weight
,
dim
=
0
)
yield
key
,
qkv_weight
yield
key
,
qkv_weight
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
def
_rename_weight_for_modelopt_checkpoint
(
self
,
name
:
str
)
->
str
:
torch
.
Tensor
]])
->
set
[
str
]:
"""Rename weights from ModelOpt llama4 fp8 checkpoints to vLLM
format."""
if
name
.
startswith
(
"model."
):
# Handle expert scale parameters with flat naming
if
"feed_forward.experts."
in
name
and
(
"_input_scale"
in
name
or
"_weight_scale"
in
name
):
renamed
=
name
.
replace
(
"model."
,
"language_model.model."
,
1
)
# Map checkpoint naming to vLLM's expected naming
if
"down_proj_input_scale"
in
renamed
:
return
renamed
.
replace
(
"down_proj_input_scale"
,
"w2_input_scale"
)
elif
"down_proj_weight_scale"
in
renamed
:
return
renamed
.
replace
(
"down_proj_weight_scale"
,
"w2_weight_scale"
)
elif
"gate_up_proj_input_scale"
in
renamed
:
return
renamed
.
replace
(
"gate_up_proj_input_scale"
,
"w13_input_scale"
)
elif
"gate_up_proj_weight_scale"
in
renamed
:
return
renamed
.
replace
(
"gate_up_proj_weight_scale"
,
"w13_weight_scale"
)
return
renamed
# Handle attention scale parameters
elif
"self_attn."
in
name
and
(
".k_scale"
in
name
or
".v_scale"
in
name
):
renamed
=
name
.
replace
(
"model."
,
"language_model.model."
,
1
)
if
".k_proj.k_scale"
in
renamed
:
return
renamed
.
replace
(
".k_proj.k_scale"
,
".attn.k_scale"
)
elif
".v_proj.v_scale"
in
renamed
:
return
renamed
.
replace
(
".v_proj.v_scale"
,
".attn.v_scale"
)
return
renamed
# Standard model.* to language_model.model.* renaming
return
name
.
replace
(
"model."
,
"language_model.model."
,
1
)
elif
name
.
startswith
(
"lm_head.weight"
):
return
name
.
replace
(
"lm_head.weight"
,
"language_model.lm_head.weight"
)
return
name
def
_separate_and_rename_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
)
->
tuple
[
list
[
tuple
[
str
,
torch
.
Tensor
]],
list
[
tuple
[
str
,
torch
.
Tensor
]]]:
"""Rename weights and separate them into language_model and other
weights."""
language_model_weights
=
[]
other_weights
=
[]
stacked_params_mapping
=
[
for
name
,
weight
in
weights
:
# (param_name, shard_name, shard_id)
renamed
=
self
.
_rename_weight_for_modelopt_checkpoint
(
name
)
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
:
set
[
str
]
=
set
()
# language_model is an Llama4ForCausalLM instance. We load it's
if
renamed
.
startswith
(
"language_model."
):
# using llama4's load_weights routine.
language_model_weights
.
append
((
renamed
,
weight
))
language_model_weights
,
other_weights
=
self
.
separate_weights
(
else
:
weights
,
prefix
=
"language_model."
)
other_weights
.
append
((
renamed
,
weight
))
loader
=
AutoWeightsLoader
(
self
)
loaded_language_model_params
=
loader
.
load_weights
(
return
language_model_weights
,
other_weights
language_model_weights
)
assert
loaded_language_model_params
is
not
None
def
_handle_expert_scale_broadcasting
(
updated_params
.
update
(
loaded_language_model_params
)
self
,
weights
:
list
[
tuple
[
str
,
torch
.
Tensor
]],
params_dict
:
dict
)
->
tuple
[
list
[
tuple
[
str
,
torch
.
Tensor
]],
set
[
str
]]:
"""Handle expert scale parameters that need broadcasting.
ModelOpt checkpoints use a single value tensor scalar for BMM style
experts, vLLM expects the scale to be broadcasted across all experts.
"""
regular_weights
=
[]
expert_scale_weights
=
[]
updated_params
=
set
()
for
name
,
weight
in
weights
:
# Check if this is an expert scale parameter that needs broadcasting
if
(
"feed_forward.experts."
in
name
and
"scale"
in
name
and
".shared_expert"
not
in
name
):
if
name
in
params_dict
:
param
=
params_dict
[
name
]
if
(
hasattr
(
param
,
'data'
)
and
param
.
data
.
numel
()
>
1
and
weight
.
numel
()
==
1
):
# Broadcast single value to all experts
param
.
data
.
fill_
(
weight
.
item
())
updated_params
.
add
(
name
)
continue
expert_scale_weights
.
append
((
name
,
weight
))
else
:
regular_weights
.
append
((
name
,
weight
))
return
regular_weights
,
expert_scale_weights
,
updated_params
def
_load_other_weights
(
self
,
other_weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
params_dict
:
dict
,
stacked_params_mapping
:
list
)
->
set
[
str
]:
"""Load non-language-model weights with stacking support."""
updated_params
=
set
()
if
self
.
use_data_parallel
:
if
self
.
use_data_parallel
:
other_weights
=
self
.
_consolidate_qkv_weights
(
other_weights
)
other_weights
=
self
.
_consolidate_qkv_weights
(
other_weights
)
for
name
,
loaded_weight
in
other_weights
:
for
name
,
loaded_weight
in
other_weights
:
# Try stacked parameter mapping first
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
self
.
use_data_parallel
:
if
weight_name
not
in
name
or
self
.
use_data_parallel
:
continue
continue
...
@@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
...
@@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader
(
param
,
loaded_weight
,
shard_id
)
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
break
else
:
else
:
# Use regular weight loading
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
weight_loader
(
param
,
loaded_weight
)
updated_params
.
add
(
name
)
updated_params
.
add
(
name
)
return
updated_params
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
# Shared expert gate_up_proj stacking
(
".shared_expert.gate_up_proj"
,
".shared_expert.gate_proj"
,
0
),
(
".shared_expert.gate_up_proj"
,
".shared_expert.up_proj"
,
1
),
# Feed forward gate_up_proj stacking (for non-MoE layers if any)
(
".feed_forward.gate_up_proj"
,
".feed_forward.gate_proj"
,
0
),
(
".feed_forward.gate_up_proj"
,
".feed_forward.up_proj"
,
1
),
]
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
:
set
[
str
]
=
set
()
# Separate and rename weights
language_model_weights
,
other_weights
=
(
self
.
_separate_and_rename_weights
(
weights
))
# Handle expert scale parameters
regular_weights
,
expert_scale_weights
,
updated_params_from_experts
=
(
self
.
_handle_expert_scale_broadcasting
(
language_model_weights
,
params_dict
))
updated_params
.
update
(
updated_params_from_experts
)
loader
=
AutoWeightsLoader
(
self
)
loaded_language_model_params
=
loader
.
load_weights
(
regular_weights
)
assert
loaded_language_model_params
is
not
None
updated_params
.
update
(
loaded_language_model_params
)
if
expert_scale_weights
:
loaded_expert_scale_params
=
loader
.
load_weights
(
expert_scale_weights
)
if
loaded_expert_scale_params
:
updated_params
.
update
(
loaded_expert_scale_params
)
updated_params
.
update
(
self
.
_load_other_weights
(
other_weights
,
params_dict
,
stacked_params_mapping
))
return
updated_params
return
updated_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment