Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4afe687a
Unverified
Commit
4afe687a
authored
Jul 11, 2025
by
Zhiyu
Committed by
GitHub
Jul 11, 2025
Browse files
Enable ModelOpt Llama4 fp8 checkpoint deployment (#20419)
Signed-off-by:
Zhiyu Cheng
<
zhiyuc@nvidia.com
>
parent
5de8d9f1
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
501 additions
and
35 deletions
+501
-35
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+31
-6
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+261
-5
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+10
-0
vllm/model_executor/models/llama4.py
vllm/model_executor/models/llama4.py
+55
-4
vllm/model_executor/models/mllama4.py
vllm/model_executor/models/mllama4.py
+144
-20
No files found.
vllm/model_executor/layers/fused_moe/layer.py
View file @
4afe687a
...
...
@@ -81,6 +81,16 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
raise
NotImplementedError
def uses_weight_scale_2_pattern(self) -> bool:
    """Report whether per-tensor weight scales are named 'weight_scale_2'.

    The base implementation answers ``False``, i.e. the standard
    'weight_scale' naming is assumed. Quantization methods that store
    their per-tensor weight scales under 'weight_scale_2' (e.g. FP4
    variants) are expected to override this and return ``True``.
    """
    uses_pattern = False
    return uses_pattern
@
staticmethod
def
maybe_make_prepare_finalize
(
moe
:
FusedMoEConfig
)
->
Optional
[
FusedMoEPrepareAndFinalize
]:
...
...
@@ -1081,12 +1091,23 @@ class FusedMoE(torch.nn.Module):
# TODO @dsikka: ModelOpt should follow the proper MoE loading pattern
if
"ModelOpt"
in
quant_method_name
:
if
(
'weight_scale_2'
in
weight_name
or
'input_scale'
in
weight_name
):
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
# Determine per-tensor weight scale patterns based on variant
# Use the dedicated method instead of brittle string matching
uses_weight_scale_2
=
self
.
quant_method
.
uses_weight_scale_2_pattern
(
)
# For per-tensor, FP4 uses "weight_scale_2", FP8 uses "weight_scale"
per_tensor_conditions
=
(
"weight_scale_2"
in
weight_name
if
uses_weight_scale_2
else
"weight_scale"
in
weight_name
)
or
"input_scale"
in
weight_name
if
per_tensor_conditions
:
self
.
_load_per_tensor_weight_scale
(
shard_id
=
shard_id
,
param
=
param
,
loaded_weight
=
loaded_weight
,
expert_id
=
expert_id
)
expert_id
=
expert_id
,
)
elif
"weight"
in
weight_name
:
self
.
_load_model_weight_or_group_weight_scale
(
shard_id
=
shard_id
,
...
...
@@ -1558,3 +1579,7 @@ direct_register_custom_op(
dispatch_key
=
current_platform
.
dispatch_key
,
tags
=
(
torch
.
Tag
.
needs_fixed_stride_order
,
),
)
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE
.
weight_loader
.
supports_moe_loading
=
True
# type: ignore[attr-defined]
vllm/model_executor/layers/quantization/modelopt.py
View file @
4afe687a
...
...
@@ -42,9 +42,13 @@ class ModelOptFp8Config(QuantizationConfig):
def
__init__
(
self
,
is_checkpoint_fp8_serialized
:
bool
=
False
,
kv_cache_quant_method
:
Optional
[
str
]
=
None
,
exclude_modules
:
Optional
[
list
[
str
]]
=
None
,
)
->
None
:
super
().
__init__
()
self
.
is_checkpoint_fp8_serialized
=
is_checkpoint_fp8_serialized
self
.
kv_cache_quant_method
=
kv_cache_quant_method
self
.
exclude_modules
=
exclude_modules
if
is_checkpoint_fp8_serialized
:
logger
.
warning
(
"Detected ModelOpt fp8 checkpoint. Please note that"
" the format is experimental and could change."
)
...
...
@@ -69,6 +73,11 @@ class ModelOptFp8Config(QuantizationConfig):
def
from_config
(
cls
,
config
:
dict
[
str
,
Any
])
->
"ModelOptFp8Config"
:
quant_config
=
cls
.
get_from_keys
(
config
,
[
"quantization"
])
quant_method
=
quant_config
[
"quant_algo"
]
kv_cache_quant_method
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"kv_cache_quant_algo"
)
exclude_modules
=
cls
.
get_from_keys
(
config
,
[
"quantization"
]).
get
(
"exclude_modules"
)
if
quant_method
not
in
QUANT_ALGOS
:
raise
ValueError
(
f
"ModelOpt currently only supports:
{
QUANT_ALGOS
}
"
" quantizations in vLLM. Please check the "
...
...
@@ -76,15 +85,39 @@ class ModelOptFp8Config(QuantizationConfig):
"quant configuration."
)
is_checkpoint_fp8_serialized
=
(
"FP8"
in
quant_method
)
return
cls
(
is_checkpoint_fp8_serialized
)
return
cls
(
is_checkpoint_fp8_serialized
,
kv_cache_quant_method
,
exclude_modules
)
def is_layer_excluded(self, prefix: str) -> bool:
    """Check if a layer should be excluded from quantization.

    A layer is excluded when any entry of ``self.exclude_modules`` occurs
    as a substring of ``prefix``. This also covers multimodal models whose
    layer names carry a ``language_model.`` prefix: an entry that matches
    the name without that prefix necessarily matches the full name too, so
    no separate prefix-stripping check is required.

    Args:
        prefix: Fully qualified name of the layer being inspected.

    Returns:
        True if the layer matches an excluded module, False otherwise
        (including when no exclude list is configured).
    """
    if self.exclude_modules is None:
        return False
    return any(module in prefix for module in self.exclude_modules)
def get_quant_method(self, layer: torch.nn.Module,
                     prefix: str) -> Optional["QuantizeMethodBase"]:
    """Pick the quantization method object for ``layer``.

    Linear layers get the FP8 linear method unless ``prefix`` is excluded
    via ``is_layer_excluded``, in which case they stay unquantized.
    Attention layers get the FP8 KV-cache method and fused MoE layers get
    the FP8 MoE method. Any other layer type yields ``None``.
    """
    # Imported here rather than at module level to avoid a circular import.
    from vllm.attention.layer import Attention

    if isinstance(layer, LinearBase):
        excluded = self.is_layer_excluded(prefix)
        return (UnquantizedLinearMethod()
                if excluded else ModelOptFp8LinearMethod(self))

    if isinstance(layer, Attention):
        return ModelOptFp8KVCacheMethod(self)

    if isinstance(layer, FusedMoE):
        return ModelOptFp8MoEMethod(self)

    return None
...
...
@@ -172,6 +205,223 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
bias
=
bias
)
class ModelOptFp8MoEMethod(FusedMoEMethodBase):
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and
    activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __init__(self, quant_config: ModelOptFp8Config):
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            cutlass_fp8_supported)

        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the expert weights (and, for FP8 checkpoints, scales).

        ``w13_weight`` and ``w2_weight`` are always registered. Weight and
        input scale parameters are registered only when the checkpoint is
        FP8-serialized; all scales are initialised to 1.0 in float32.
        """
        serialized = self.quant_config.is_checkpoint_fp8_serialized
        # Use FP8 dtype if checkpoint is serialized
        weight_dtype = torch.float8_e4m3fn if serialized else params_dtype
        weight_loader = extra_weight_attrs.get("weight_loader")

        def _unit_scale(*shape):
            # Per-tensor scale parameter of the given shape, filled with 1.0.
            return PerTensorScaleParameter(
                data=torch.full(shape, 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )

        w13_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             2 * intermediate_size_per_partition,
                             hidden_size,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(num_experts,
                             hidden_size,
                             intermediate_size_per_partition,
                             dtype=weight_dtype),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if not serialized:
            return

        # WEIGHT SCALES - Per-tensor scaling for ModelOpt
        # Two scales per expert are allocated for w13 (one for w1, one for
        # w3); they are combined into a single scale after weight loading.
        w13_weight_scale = _unit_scale(num_experts, 2)
        w2_weight_scale = _unit_scale(num_experts)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

        # Set weight loader attributes for scales
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # INPUT SCALES - Per-tensor scaling for ModelOpt
        w13_input_scale = _unit_scale(num_experts)
        w2_input_scale = _unit_scale(num_experts)
        layer.register_parameter("w13_input_scale", w13_input_scale)
        layer.register_parameter("w2_input_scale", w2_input_scale)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        The weights are re-wrapped as plain non-trainable parameters, the
        two per-expert w13 weight scales are merged into one (requantizing
        the weights accordingly), and each input scale is reduced to its
        maximum over all experts.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data,
                                     requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        from vllm._custom_ops import scaled_fp8_quant
        from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
            per_tensor_dequantize)

        if getattr(layer, "w13_weight_scale", None) is not None:
            if layer.w13_weight_scale.dim() == 2:
                # The fp8 MoE kernel needs a single w13 weight scale per
                # expert: take the max of the w1 and w3 scales, then
                # dequantize and requantize each half with that scale.
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # w13_weight is (num_experts, 2 * intermediate_size,
                # hidden_size): the first intermediate_size rows are w1,
                # the next intermediate_size rows are w3.
                intermediate_size = layer.w13_weight.shape[1] // 2
                num_local_experts = layer.w13_weight.shape[0]
                for expert_id in range(num_local_experts):
                    for shard_id in range(2):  # w1 and w3
                        begin = shard_id * intermediate_size
                        rows = slice(begin, begin + intermediate_size)
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][rows, :],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale
                        requantized, _ = scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                        layer.w13_weight[expert_id][rows, :] = requantized

                # Update the scale parameter to be per-expert
                layer.w13_weight_scale = Parameter(max_w13_scales,
                                                   requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False)

        if getattr(layer, "w2_weight_scale", None) is not None:
            layer.w2_weight_scale = Parameter(layer.w2_weight_scale.data,
                                              requires_grad=False)

        # Input scales must be equal for each expert in fp8 MoE layers.
        if getattr(layer, "w13_input_scale", None) is not None:
            layer.w13_input_scale = Parameter(layer.w13_input_scale.max(),
                                              requires_grad=False)
        if getattr(layer, "w2_input_scale", None) is not None:
            layer.w2_input_scale = Parameter(layer.w2_input_scale.max(),
                                             requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Route tokens to experts and run the fused FP8 expert kernel.

        Raises:
            NotImplementedError: If ``enable_eplb`` is set; EPLB is not
                supported by this method.
        """
        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `ModelOptFp8MoEMethod` yet.")

        # Expert selection
        routing_kwargs = dict(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )
        topk_weights, topk_ids = FusedMoE.select_experts(**routing_kwargs)

        from vllm.model_executor.layers.fused_moe.fused_moe import (
            fused_experts)

        output = fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            apply_router_weight_on_input=apply_router_weight_on_input,
        )
        return output
class
ModelOptNvFp4Config
(
QuantizationConfig
):
"""Config class for ModelOpt FP4."""
...
...
@@ -472,6 +722,12 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" quantization. Please use Blackwell and"
" above."
)
def uses_weight_scale_2_pattern(self) -> bool:
    """Always ``True`` for this method.

    FP4 variants name their per-tensor weight scales 'weight_scale_2'.
    """
    uses_pattern = True
    return uses_pattern
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
...
...
vllm/model_executor/model_loader/weight_utils.py
View file @
4afe687a
...
...
@@ -762,6 +762,10 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
modelopt_scale_names
=
[
".self_attn.k_proj.k_scale"
,
".self_attn.v_proj.v_scale"
]
# Also support qkv_proj scale parameters (from stacked parameter processing)
qkv_proj_scale_names
=
[
".self_attn.qkv_proj.k_scale"
,
".self_attn.qkv_proj.v_scale"
]
for
scale_name
in
possible_scale_names
:
if
name
.
endswith
(
scale_name
):
if
any
(
mo_scale_name
in
name
...
...
@@ -769,6 +773,12 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
remapped_name
=
name
.
replace
(
f
".self_attn.
{
scale_name
[
1
]
}
_proj
{
scale_name
}
"
,
f
".self_attn.attn
{
scale_name
}
"
)
elif
any
(
qkv_scale_name
in
name
for
qkv_scale_name
in
qkv_proj_scale_names
):
# Handle qkv_proj scale parameters
remapped_name
=
name
.
replace
(
f
".self_attn.qkv_proj
{
scale_name
}
"
,
f
".self_attn.attn
{
scale_name
}
"
)
else
:
remapped_name
=
name
.
replace
(
scale_name
,
f
".attn
{
scale_name
}
"
)
if
remapped_name
not
in
params_dict
:
...
...
vllm/model_executor/models/llama4.py
View file @
4afe687a
...
...
@@ -35,7 +35,8 @@ from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear
)
from
vllm.model_executor.layers.quantization
import
QuantizationConfig
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
.llama
import
LlamaForCausalLM
,
LlamaMLP
,
LlamaModel
from
.utils
import
(
AutoWeightsLoader
,
extract_layer_index
,
fast_topk
,
...
...
@@ -432,11 +433,23 @@ class Llama4Model(LlamaModel):
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
"experts"
in
name
:
continue
# This check is for ModelOpt ckpts with kv cache quant enabled
if
not
(
name
.
endswith
(
(
".k_scale"
,
".v_scale"
))
and
"self_attn"
in
name
):
name
=
name
.
replace
(
weight_name
,
param_name
)
if
is_pp_missing_parameter
(
name
,
self
):
continue
if
name
.
endswith
(
"scale"
)
and
"expert"
not
in
name
:
# Remapping the name of FP8 kv-scale.
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
if
name
is
None
:
continue
param
=
params_dict
[
name
]
weight_loader
=
param
.
weight_loader
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
if
weight_loader
==
default_weight_loader
:
weight_loader
(
param
,
loaded_weight
)
else
:
weight_loader
(
param
,
loaded_weight
,
shard_id
)
loaded_params
.
add
(
name
)
break
...
...
@@ -452,6 +465,44 @@ class Llama4Model(LlamaModel):
if
not
moe_loaded
:
if
is_pp_missing_parameter
(
name
,
self
):
continue
# Handle flat expert scale parameters that
# don't match per-expert patterns
if
(
"experts."
in
name
and
(
"w13_input_scale"
in
name
or
"w13_weight_scale"
in
name
or
"w2_input_scale"
in
name
or
"w2_weight_scale"
in
name
)):
# These are flat expert scales that apply to all experts
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
# Check for MoE-specific loading support via
# attribute instead of expensive runtime reflection
supports_moe
=
getattr
(
weight_loader
,
'supports_moe_loading'
,
False
)
if
supports_moe
:
# This is a MoE weight loader
if
"w13_"
in
name
:
shard_id
=
"w1"
elif
"w2_"
in
name
:
shard_id
=
"w2"
else
:
shard_id
=
"w1"
weight_loader
(
param
,
loaded_weight
,
name
,
shard_id
=
shard_id
,
expert_id
=
0
)
else
:
# Regular weight loader (handles both
# param.weight_loader and default_weight_loader)
weight_loader
(
param
,
loaded_weight
)
loaded_params
.
add
(
name
)
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
...
...
vllm/model_executor/models/mllama4.py
View file @
4afe687a
...
...
@@ -717,6 +717,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP
):
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"gate_up_proj"
:
[
"gate_proj"
,
"up_proj"
],
}
@
classmethod
...
...
@@ -902,32 +903,109 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
qkv_weight
=
torch
.
cat
(
weight
,
dim
=
0
)
yield
key
,
qkv_weight
def
load_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]])
->
set
[
str
]:
def _rename_weight_for_modelopt_checkpoint(self, name: str) -> str:
    """Translate a ModelOpt llama4 fp8 checkpoint weight name to vLLM's.

    * ``lm_head.weight...`` gains a ``language_model.`` prefix.
    * ``model....`` becomes ``language_model.model....`` (first occurrence
      only). On top of that, flat expert scale names under
      ``feed_forward.experts.`` are mapped to the ``w2_*`` / ``w13_*``
      names, and ``k_proj.k_scale`` / ``v_proj.v_scale`` under
      ``self_attn.`` are moved to ``attn.k_scale`` / ``attn.v_scale``.
    * Any other name is returned unchanged.
    """
    if name.startswith("lm_head.weight"):
        return name.replace("lm_head.weight",
                            "language_model.lm_head.weight")
    if not name.startswith("model."):
        return name

    # Standard model.* to language_model.model.* renaming
    renamed = name.replace("model.", "language_model.model.", 1)

    # Expert scale parameters with flat naming
    is_expert_scale = ("feed_forward.experts." in name
                       and ("_input_scale" in name
                            or "_weight_scale" in name))
    # Attention (kv-cache) scale parameters
    is_attn_scale = ("self_attn." in name
                     and (".k_scale" in name or ".v_scale" in name))

    if is_expert_scale:
        # Checkpoint naming -> vLLM's expected naming; first match wins.
        substitutions = (
            ("down_proj_input_scale", "w2_input_scale"),
            ("down_proj_weight_scale", "w2_weight_scale"),
            ("gate_up_proj_input_scale", "w13_input_scale"),
            ("gate_up_proj_weight_scale", "w13_weight_scale"),
        )
    elif is_attn_scale:
        substitutions = (
            (".k_proj.k_scale", ".attn.k_scale"),
            (".v_proj.v_scale", ".attn.v_scale"),
        )
    else:
        substitutions = ()

    for checkpoint_key, vllm_key in substitutions:
        if checkpoint_key in renamed:
            return renamed.replace(checkpoint_key, vllm_key)
    return renamed
def
_separate_and_rename_weights
(
self
,
weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]]
)
->
tuple
[
list
[
tuple
[
str
,
torch
.
Tensor
]],
list
[
tuple
[
str
,
torch
.
Tensor
]]]:
"""Rename weights and separate them into language_model and other
weights."""
language_model_weights
=
[]
other_weights
=
[]
stacked_params_mapping
=
[
# (param_name, shard_name, shard_id)
(
".self_attn.qkv_proj"
,
".self_attn.q_proj"
,
"q"
),
(
".self_attn.qkv_proj"
,
".self_attn.k_proj"
,
"k"
),
(
".self_attn.qkv_proj"
,
".self_attn.v_proj"
,
"v"
),
]
params_dict
=
dict
(
self
.
named_parameters
())
updated_params
:
set
[
str
]
=
set
()
for
name
,
weight
in
weights
:
renamed
=
self
.
_rename_weight_for_modelopt_checkpoint
(
name
)
# language_model is a Llama4ForCausalLM instance. We load it
# using llama4's load_weights routine.
language_model_weights
,
other_weights
=
self
.
separate_weights
(
weights
,
prefix
=
"language_model."
)
loader
=
AutoWeightsLoader
(
self
)
loaded_language_model_params
=
loader
.
load_weights
(
language_model_weights
)
assert
loaded_language_model_params
is
not
None
updated_params
.
update
(
loaded_language_model_params
)
if
renamed
.
startswith
(
"language_model."
):
language_model_weights
.
append
((
renamed
,
weight
))
else
:
other_weights
.
append
((
renamed
,
weight
))
return
language_model_weights
,
other_weights
def _handle_expert_scale_broadcasting(
    self, weights: list[tuple[str, torch.Tensor]], params_dict: dict
) -> tuple[list[tuple[str, torch.Tensor]], list[tuple[str, torch.Tensor]],
           set[str]]:
    """Handle expert scale parameters that need broadcasting.

    ModelOpt checkpoints use a single value tensor scalar for BMM style
    experts, vLLM expects the scale to be broadcasted across all experts.

    A weight counts as an expert scale when its name contains
    ``feed_forward.experts.`` and ``scale`` but not ``.shared_expert``.
    If such a weight holds exactly one element while the matching entry
    in ``params_dict`` holds more than one, the parameter is filled in
    place with that value and the weight is consumed here.

    Args:
        weights: ``(name, tensor)`` pairs to classify.
        params_dict: Mapping from parameter name to parameter; entries
            that get broadcast are modified in place.

    Returns:
        A 3-tuple ``(regular_weights, expert_scale_weights,
        updated_params)``: the non-expert-scale weights, the expert scale
        weights that still need normal loading, and the names of the
        parameters that were filled by broadcasting.
    """
    regular_weights: list[tuple[str, torch.Tensor]] = []
    expert_scale_weights: list[tuple[str, torch.Tensor]] = []
    updated_params: set[str] = set()

    for name, weight in weights:
        is_expert_scale = ("feed_forward.experts." in name
                           and "scale" in name
                           and ".shared_expert" not in name)
        if not is_expert_scale:
            regular_weights.append((name, weight))
            continue

        param = params_dict.get(name)
        needs_broadcast = (param is not None and hasattr(param, "data")
                           and param.data.numel() > 1
                           and weight.numel() == 1)
        if needs_broadcast:
            # Broadcast single value to all experts
            param.data.fill_(weight.item())
            updated_params.add(name)
            continue

        expert_scale_weights.append((name, weight))

    return regular_weights, expert_scale_weights, updated_params
def
_load_other_weights
(
self
,
other_weights
:
Iterable
[
tuple
[
str
,
torch
.
Tensor
]],
params_dict
:
dict
,
stacked_params_mapping
:
list
)
->
set
[
str
]:
"""Load non-language-model weights with stacking support."""
updated_params
=
set
()
if
self
.
use_data_parallel
:
other_weights
=
self
.
_consolidate_qkv_weights
(
other_weights
)
for
name
,
loaded_weight
in
other_weights
:
# Try stacked parameter mapping first
for
param_name
,
weight_name
,
shard_id
in
stacked_params_mapping
:
if
weight_name
not
in
name
or
self
.
use_data_parallel
:
continue
...
...
@@ -938,10 +1016,56 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
weight_loader
(
param
,
loaded_weight
,
shard_id
)
break
else
:
# Use regular weight loading
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
updated_params
.
add
(
name
)
return
updated_params
def load_weights(self, weights: Iterable[tuple[str,
                                               torch.Tensor]]) -> set[str]:
    """Load checkpoint weights and return the names of updated parameters.

    The weights are renamed and split into language-model and other
    weights. Expert scales that can be broadcast are filled directly;
    the remaining language-model weights (regular first, then any leftover
    expert scales) go through ``AutoWeightsLoader``; everything else is
    handed to ``_load_other_weights`` together with the stacking table.
    """
    # (param_name, shard_name, shard_id)
    qkv_stacking = [
        (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
        (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
        (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    ]
    gate_up_stacking = [
        # Shared expert gate_up_proj stacking
        (".shared_expert.gate_up_proj", ".shared_expert.gate_proj", 0),
        (".shared_expert.gate_up_proj", ".shared_expert.up_proj", 1),
        # Feed forward gate_up_proj stacking (for non-MoE layers if any)
        (".feed_forward.gate_up_proj", ".feed_forward.gate_proj", 0),
        (".feed_forward.gate_up_proj", ".feed_forward.up_proj", 1),
    ]
    stacked_params_mapping = qkv_stacking + gate_up_stacking

    params_dict = dict(self.named_parameters())

    # Separate and rename weights
    language_model_weights, other_weights = (
        self._separate_and_rename_weights(weights))

    # Handle expert scale parameters
    regular_weights, expert_scale_weights, broadcast_params = (
        self._handle_expert_scale_broadcasting(language_model_weights,
                                               params_dict))
    updated_params: set[str] = set(broadcast_params)

    loader = AutoWeightsLoader(self)
    loaded_language_model_params = loader.load_weights(regular_weights)
    assert loaded_language_model_params is not None
    updated_params |= set(loaded_language_model_params)

    if expert_scale_weights:
        loaded_expert_scale_params = loader.load_weights(
            expert_scale_weights)
        if loaded_expert_scale_params:
            updated_params |= set(loaded_expert_scale_params)

    other_loaded = self._load_other_weights(other_weights, params_dict,
                                            stacked_params_mapping)
    updated_params |= set(other_loaded)
    return updated_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment