Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
ff08e519
Unverified
Commit
ff08e519
authored
Jul 30, 2025
by
Po-Han Huang (NVIDIA)
Committed by
GitHub
Jul 30, 2025
Browse files
[NVIDIA] Fix Llama4 Scout FP4 functionality issues (#21499)
Signed-off-by:
Po-Han Huang
<
pohanh@nvidia.com
>
parent
8f4a1c9a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
219 additions
and
70 deletions
+219
-70
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+14
-1
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+0
-2
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+205
-67
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
ff08e519
...
...
@@ -874,6 +874,14 @@ class FusedMoE(torch.nn.Module):
elif
shard_id
==
"w2"
:
param_data
[
expert_id
]
=
loaded_weight
def _load_w13_weight_scale(self, shard_dim: int,
                           loaded_weight: torch.Tensor,
                           param: torch.Tensor, tp_rank: int) -> None:
    """Copy one rank's slice of a w13 weight scale into ``param``.

    The slice length is taken from ``param`` itself, so ``loaded_weight``
    is expected to hold the slices of all ranks laid out back to back
    along ``shard_dim``.

    Args:
        shard_dim: Dimension along which ``loaded_weight`` is sliced.
        loaded_weight: Checkpoint tensor covering every rank's slice.
        param: Destination tensor; overwritten in place.
        tp_rank: Index of the slice to take (presumably the
            tensor-parallel rank -- the caller passes ``self.tp_rank``).
    """
    shard_size = param.shape[shard_dim]
    # Slice number `tp_rank` starts at tp_rank * shard_size.
    shard = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
                                 shard_size)
    # In-place writes to a leaf tensor that requires grad raise unless
    # autograd is disabled; weight loading must never be recorded anyway.
    with torch.no_grad():
        param.copy_(shard)
def
_load_model_weight_or_group_weight_scale
(
self
,
shard_dim
:
int
,
expert_data
:
torch
.
Tensor
,
...
...
@@ -1123,7 +1131,12 @@ class FusedMoE(torch.nn.Module):
"weight_scale_2"
in
weight_name
if
uses_weight_scale_2
else
"weight_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
if
per_tensor_conditions
:
if
"w13_weight_scale"
in
weight_name
:
self
.
_load_w13_weight_scale
(
shard_dim
=
shard_dim
,
loaded_weight
=
loaded_weight
,
param
=
param
,
tp_rank
=
self
.
tp_rank
)
elif
per_tensor_conditions
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
ff08e519
...
...
@@ -778,8 +778,6 @@ class ModelOptNvFp4LinearMethod(LinearMethodBase):
# Swizzle the weight blockscale.
# contracting dimension is input dimension
# block_size = 16;
assert
(
layer
.
weight_scale
.
shape
[
1
]
%
16
==
0
),
(
"Expected weight_scale.dim(1) to be divisible by 16"
)
assert
(
layer
.
weight_scale
.
dtype
==
torch
.
float8_e4m3fn
),
(
"Weight Block scale must be represented as FP8-E4M3"
)
swizzled_weight_scale
=
swizzle_blockscale
(
layer
.
weight_scale
)
...
...
vllm/model_executor/models/llama4.py
View file @
ff08e519
...
...
@@ -342,34 +342,94 @@ class Llama4Model(LlamaModel):
expert_params_mapping
:
list
[
tuple
[
str
,
str
,
int
,
str
]],
fused
:
bool
=
True
,
)
->
bool
:
"""
Load MoE expert weights.
Args:
name: The name of the weight to load.
loaded_weight: The weight to load.
params_dict: The dictionary of module parameters.
loaded_params: The set of already loaded parameters.
expert_params_mapping: The mapping of expert parameters. Must be
generated by FusedMoE.make_expert_params_mapping().
fused: Whether the expert weights are fused into a single weight
tensor or are separate weight tensors for each expert.
When fused is True, loaded_weight should have shape of:
[num_experts, hidden_in, hidden_out] for gate/up/down proj and
[hidden_out, hidden_in] for the others like router.
When fused is False, loaded_weight should have shape of:
[hidden_out, hidden_in].
Returns:
True if loaded_weight is one of MoE weights and the MoE expert
weights are loaded successfully, False otherwise.
"""
# Whether the MoE expert weights are loaded successfully.
expert_param_loaded
=
False
if
"experts.gate_up_proj"
in
name
:
loaded_weight
=
loaded_weight
.
chunk
(
2
,
dim
=-
1
)
# If fused is True, the loaded weight is in the layout of:
# [num_experts, hidden_in, hidden_out], so we must transpose the last
# two dimensions to match the expected layout of the parameters.
if
fused
and
loaded_weight
.
ndim
==
3
:
loaded_weight
=
loaded_weight
.
transpose
(
-
1
,
-
2
)
# If the gate_proj and up_proj weights are fused into a single
# weight tensor, we need to split the weight tensor into a tuple
# of two weight tensors along the hidden_out dimension.
if
"experts.gate_up_proj"
in
name
:
loaded_weight
=
loaded_weight
.
chunk
(
2
,
dim
=-
2
)
# Iterate over all the expert parameters and load the weights if we find
# a match in weight name.
for
(
param_name
,
weight_name
,
expert_id
,
shard_id
)
in
expert_params_mapping
:
# Get a view of the loaded_weight to avoid modifying the original
# one across iterations.
new_loaded_weight
=
loaded_weight
# If expert weights are fused into a single weight tensor, remove
# the expert index from the expected weight name.
if
fused
:
# The string between e_str and proj_str is the expert index.
e_str
,
_
,
proj_str
,
_
=
weight_name
.
split
(
'.'
)
weight_name
=
f
"
{
e_str
}
.
{
proj_str
}
"
param_name
=
f
"
{
param_name
}
weight"
# Skip if the current weight is not one of the MoE weights.
if
weight_name
not
in
name
:
continue
# Replace the weight name with the parameter name.
full_param_name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip layers on other devices.
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Skip if the current weight is for the bias.
if
((
name
.
endswith
(
".bias"
)
or
name
.
endswith
(
"_bias"
))
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
full_param_name
]
weight_loader
=
param
.
weight_loader
if
fused
:
# If the parameter is for w13 together, the corresponding weight
# will be a tuple, so we must select the correct weight
# depending on the shard id, which is either "w1" or "w3".
if
"w13"
in
full_param_name
:
assert
shard_id
in
[
"w1"
,
"w3"
]
shard_idx
=
0
if
shard_id
==
"w1"
else
1
new_loaded_weight
=
new_loaded_weight
[
shard_idx
]
new_loaded_weight
=
new_loaded_weight
.
transpose
(
-
1
,
-
2
)
# If EP (expert parallel) is enabled, update expert_id to the
# starting expert index for the current EP rank and extract the
# corresponding expert weights.
layer_idx
=
extract_layer_index
(
name
)
# EP mapping
expert_map
=
self
.
layers
[
layer_idx
].
feed_forward
.
experts
.
expert_map
if
expert_map
is
not
None
:
...
...
@@ -382,6 +442,9 @@ class Llama4Model(LlamaModel):
else
:
# TODO: add EP support for non fused weights
pass
# Load the weight into the module parameter with corresponding
# shard id and expert id.
weight_loader
(
param
,
new_loaded_weight
,
full_param_name
,
...
...
@@ -390,10 +453,13 @@ class Llama4Model(LlamaModel):
loaded_params
.
add
(
full_param_name
)
expert_param_loaded
=
True
return
expert_param_loaded
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
# Name mapping from the parameter name to the shard name and
# corresponding shard id.
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".qkv_proj"
,
".q_proj"
,
"q"
),
...
...
@@ -402,26 +468,43 @@ class Llama4Model(LlamaModel):
(
".gate_up_proj"
,
".gate_proj"
,
0
),
(
".gate_up_proj"
,
".up_proj"
,
1
),
]
# Indicate whether the expert weights are fused into a single weight
# tensor.
fused_experts_params
=
False
# Expert parameter mapping for the case where the expert weights are
# not fused into a single weight tensor.
expert_params_mapping
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"up_proj"
,
num_experts
=
self
.
num_experts
)
# Expert parameter mapping for the case where the expert weights are
# fused into a single weight tensor.
expert_params_mapping_fused
=
FusedMoE
.
make_expert_params_mapping
(
ckpt_gate_proj_name
=
"gate_up_proj"
,
ckpt_down_proj_name
=
"down_proj"
,
ckpt_up_proj_name
=
"gate_up_proj"
,
num_experts
=
1
)
# All the module parameters.
params_dict
=
dict
(
self
.
named_parameters
())
# The module parameters that have been loaded.
loaded_params
:
set
[
str
]
=
set
()
# Iterate over all the weights and load them into module parameters.
for
name
,
loaded_weight
in
weights
:
# If the name contains "experts.gate_up_proj" or "experts.down_proj"
# without the expert indices, it means the expert weights are fused
# into a single weight tensor across all experts.
if
"experts.gate_up_proj"
in
name
or
"experts.down_proj"
in
name
:
fused_experts_params
=
True
expert_params_mapping
=
expert_params_mapping_fused
# If kv cache quantization scales exist and the weight name
# corresponds to one of the kv cache quantization scales, load
# them.
if
(
self
.
quant_config
is
not
None
and
(
scale_name
:
=
self
.
quant_config
.
get_cache_scale
(
name
))):
# Loading kv cache quantization scales
param
=
params_dict
[
scale_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
@@ -430,84 +513,119 @@ class Llama4Model(LlamaModel):
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
scale_name
)
continue
# Iterate over stacked_params_mapping to check if the current weight
# is one of the stacked parameters. If so, load the weight with the
# corresponding shard id. Note that MoE weights are handled
# separately in the else block.
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
# Skip if the current weight is not one of the stacked
# parameters or if the current weight is a MoE weight.
if
weight_name
not
in
name
or
"experts"
in
name
:
continue
# This check is for ModelOpt ckpts with kv cache quant enabled
# For ModelOpt checkpoints, we need to rename the self_attn
# weight/weight_scale names except for kv cache scales.
if
not
(
name
.
endswith
(
(
".k_scale"
,
".v_scale"
))
and
"self_attn"
in
name
):
name
=
name
.
replace
(
weight_name
,
param_name
)
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
.
endswith
(
"scale"
)
and
"expert"
not
in
name
:
# Remapping the name of FP8 kv-scale.
# Remap kv cache scale names for ModelOpt checkpoints.
# TODO: ModelOpt should implement get_cache_scale() such that
# kv cache scale name remapping can be done there.
if
name
.
endswith
(
"scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
# Load the weight into the module parameter with corresponding
# shard id and exit the for loop and the else block.
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
break
# Handle normal (non-stacked) weights and MoE weights.
else
:
moe_loaded
=
self
.
load_moe_expert_weights
(
name
,
loaded_weight
,
params_dict
,
loaded_params
,
expert_params_mapping
,
fused
=
fused_experts_params
)
if
not
moe_loaded
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
# First, try to load MoE weights using load_moe_expert_weights.
# If successful, move on to next loaded weight.
if
self
.
load_moe_expert_weights
(
name
,
loaded_weight
,
params_dict
,
loaded_params
,
expert_params_mapping
,
fused
=
fused_experts_params
):
continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if
(
"experts."
in
name
and
(
"w13_input_scale"
in
name
or
"w13_weight_scale"
in
name
or
"w2_input_scale"
in
name
or
"w2_weight_scale"
in
name
)):
# These are flat expert scales that apply to all experts
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe
=
getattr
(
weight_loader
,
'supports_moe_loading'
,
False
)
if
supports_moe
:
# This is a MoE weight loader
if
"w13_"
in
name
:
shard_id
=
"w1"
elif
"w2_"
in
name
:
shard_id
=
"w2"
else
:
shard_id
=
"w1"
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
0
)
else
:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
continue
# Skip if the current weight corresponds to a parameter that
# does not exist on the current PP (pipeline parallel) rank.
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Handle flat expert scale parameters that don't match
# per-expert patterns, i.e. one weight scale tensor for all
# experts.
scale_names
=
[
"w13_input_scale"
,
"w13_weight_scale"
,
"w2_input_scale"
,
"w2_weight_scale"
]
if
(
"experts."
in
name
and
any
(
scale_name
in
name
for
scale_name
in
scale_names
)):
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
# If weight loader supports special moe loading, use it to
# avoid expensive runtime reflection
if
getattr
(
weight_loader
,
'supports_moe_loading'
,
False
):
# Map the weight name to the corresponding shard id.
shard_id
=
"w2"
if
"w2_"
in
name
else
"w1"
# Transpose if weight scales are FP8 block scales with
# three dimensions:
# [num_experts, hidden_in, hidden_out].
if
name
.
endswith
(
"weight_scale"
)
\
and
loaded_weight
.
dtype
==
torch
.
float8_e4m3fn
\
and
loaded_weight
.
ndim
==
3
:
loaded_weight
=
loaded_weight
.
transpose
(
-
1
,
-
2
)
# Load the weight into the module parameter with
# corresponding shard id and expert id.
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
0
)
else
:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
continue
# Handle normal (non-stacked, non-MoE) weights.
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
# Finally, return the set of loaded parameters.
return
loaded_params
...
...
@@ -560,23 +678,43 @@ class Llama4ForCausalLM(LlamaForCausalLM):
loaded_weight
:
torch
.
Tensor
,
)
->
tuple
[
str
,
torch
.
Tensor
]:
def
permute
(
w
:
torch
.
Tensor
,
n_heads
:
int
):
# Helper function to permute the weight's channels
def permute(w: torch.Tensor, n_heads: int, is_weight_scale: bool):
    """Reorder the rows of a 2-D Q/K tensor head by head.

    Each head's rows are viewed as (rows_per_head // 2, 2) and the two
    axes are swapped before flattening back to (rows, cols). Reads
    ``self.config`` from the enclosing method's scope.

    Args:
        w: Weight (or weight block scale) tensor to permute.
        n_heads: Number of heads the rows are split into.
        is_weight_scale: Whether ``w`` is a weight scale, not a weight.
    """
    # The target shape comes from the model config, not from w, because
    # w may arrive in a different layout.
    rows = self.config.head_dim * n_heads
    cols = self.config.hidden_size

    is_packed_fp4 = w.dtype == torch.uint8 and w.shape[1] * 2 == cols
    if is_packed_fp4:
        # FP4 weights packed as uint8 have half as many columns.
        cols //= 2
    elif (w.dtype == torch.float8_e4m3fn and is_weight_scale
          and w.shape[1] * 16 == cols):
        # Weight scales have one column per block; block size is 16.
        cols //= 16

    pairs_per_head = rows // n_heads // 2
    grouped = w.view(n_heads, pairs_per_head, 2, cols)
    return grouped.transpose(1, 2).reshape(rows, cols)
modules
=
name
.
split
(
"."
)
# rotary embeds should be sliced
if
(
"wk"
in
modules
or
"k_proj"
in
modules
)
\
and
modules
[
-
1
]
==
"weight"
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_key_value_heads
)
elif
(
"wq"
in
modules
or
"q_proj"
in
modules
)
\
and
modules
[
-
1
]
==
"weight"
:
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_attention_heads
)
# Permute Q/K weights and weight block scales for rotary embedding
is_weight
=
modules
[
-
1
]
==
"weight"
is_nvfp4_weight_scale
=
(
modules
[
-
1
]
==
"weight_scale"
and
loaded_weight
.
dtype
==
torch
.
float8_e4m3fn
)
if
is_weight
or
is_nvfp4_weight_scale
:
if
(
"wk"
in
modules
or
"k_proj"
in
modules
):
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_key_value_heads
,
is_nvfp4_weight_scale
)
elif
(
"wq"
in
modules
or
"q_proj"
in
modules
):
loaded_weight
=
permute
(
loaded_weight
,
self
.
config
.
num_attention_heads
,
is_nvfp4_weight_scale
)
return
name
,
loaded_weight
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment