sglang · Commits · 659907e3

Unverified commit 659907e3, authored Jul 08, 2025 by Zhiyu, committed by GitHub on Jul 08, 2025.
Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)
Parent: cb9d91ea
Showing 3 changed files with 643 additions and 81 deletions (+643 -81).
python/sglang/srt/layers/moe/fused_moe_triton/layer.py    +39 -1
python/sglang/srt/layers/quantization/modelopt_quant.py   +244 -1
python/sglang/srt/models/mllama4.py                       +360 -79
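For context before the per-file diffs, here is a rough usage sketch of what this commit enables: serving a ModelOpt-exported Llama4 FP8 checkpoint through the standard SGLang server entry point. The model path, tensor-parallel size, and port are illustrative placeholders, not values taken from this commit.

    import subprocess

    # Launch the SGLang server on a ModelOpt Llama4 FP8 checkpoint.
    # "/path/to/llama4-modelopt-fp8" is a placeholder, not a verified model id.
    subprocess.run([
        "python", "-m", "sglang.launch_server",
        "--model-path", "/path/to/llama4-modelopt-fp8",
        "--tp", "8",
        "--port", "30000",
    ])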
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
         loaded_weight: torch.tensor,
         tp_rank: int,
     ):
+        """Load w2 weights for down projection.
+
+        Args:
+            expert_data: The expert data tensor to load into
+            shard_dim: The dimension to shard along
+            shard_id: The shard ID (must be "w2")
+            loaded_weight: The weight tensor to load from
+            tp_rank: The tensor parallel rank
+        """
+        if not isinstance(expert_data, torch.Tensor) or not isinstance(
+            loaded_weight, torch.Tensor
+        ):
+            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
+
+        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
+            raise ValueError(
+                f"Expected 2D tensors, got expert_data shape {expert_data.shape} "
+                f"and loaded_weight shape {loaded_weight.shape}"
+            )
+
+        if shard_id != "w2":
+            raise ValueError(f"shard_id must be 'w2', got {shard_id}")
+
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
         if not self.use_presharded_weights:
             if self.use_triton_kernels:
                 loaded_weight = loaded_weight.transpose(-2, -1)
+            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
+                raise ValueError(
+                    f"Shard size {shard_size} at rank {tp_rank} exceeds loaded_weight dimension {loaded_weight.shape[shard_dim]}"
+                )
             loaded_weight = loaded_weight.narrow(
                 shard_dim, shard_size * tp_rank, shard_size
             )
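The RowParallel sharding above just slices a contiguous block of the sharded dimension per rank, and the new check rejects slices that would run past the tensor. A standalone sketch of the same arithmetic, not part of the diff, with made-up sizes:

    import torch

    # Hypothetical sizes: a (2, 8) weight split across tp_size = 4 ranks on dim 1.
    full_weight = torch.arange(2 * 8, dtype=torch.float32).reshape(2, 8)
    tp_size, shard_dim = 4, 1
    shard_size = full_weight.shape[shard_dim] // tp_size

    for tp_rank in range(tp_size):
        # Same bounds check as the diff: the slice must stay inside the tensor.
        assert shard_size * tp_rank + shard_size <= full_weight.shape[shard_dim]
        shard = full_weight.narrow(shard_dim, shard_size * tp_rank, shard_size)
        print(tp_rank, shard.shape)  # each rank sees a (2, 2) slice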
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
                 tp_rank=tp_rank,
             )
             return

         if "ModelOpt" in self.quant_method.__class__.__name__:
-            if "weight_scale_2" in weight_name or "input_scale" in weight_name:
+            # Determine per-tensor weight scale patterns based on variant
+            is_fp4_variant = (
+                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
+            )
+            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
+            per_tensor_conditions = (
+                "weight_scale_2" in weight_name
+                if is_fp4_variant
+                else "weight_scale" in weight_name
+            ) or "input_scale" in weight_name
+
+            if per_tensor_conditions:
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
...
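To make the per-tensor naming convention above explicit, here is a tiny standalone check that mirrors the condition; the quant-method class names are the ones referenced in the diff, everything else is illustrative:

    def is_per_tensor_scale(weight_name: str, quant_method_cls_name: str) -> bool:
        # FP4 checkpoints mark per-tensor weight scales as "weight_scale_2",
        # FP8 checkpoints use "weight_scale"; both share "input_scale".
        is_fp4_variant = "ModelOptNvFp4FusedMoEMethod" in quant_method_cls_name
        pattern = "weight_scale_2" if is_fp4_variant else "weight_scale"
        return pattern in weight_name or "input_scale" in weight_name

    assert is_per_tensor_scale("experts.w13_weight_scale", "ModelOptFp8MoEMethod")
    assert is_per_tensor_scale("experts.w13_weight_scale_2", "ModelOptNvFp4FusedMoEMethod")
    assert is_per_tensor_scale("experts.w13_input_scale", "ModelOptFp8MoEMethod")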
python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
 from sglang.srt.layers.quantization.utils import (
     convert_to_channelwise,
     is_layer_skipped,
+    per_tensor_dequantize,
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if self.exclude_modules and any(
-            module in prefix for module in self.exclude_modules
+            module in prefix
+            or (
+                prefix.startswith("language_model.")
+                and module in prefix.removeprefix("language_model.")
+            )
+            for module in self.exclude_modules
         ):
             return None
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
         if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
             return ModelOptFp8KVCacheMethod(self)
+
+        # Add MoE support
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return ModelOptFp8MoEMethod(self)
+
         return None

     def get_scaled_act_names(self) -> List[str]:
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
         super().__init__(quant_config)


+class ModelOptFp8MoEMethod:
+    """MoE method for ModelOpt FP8.
+
+    Supports loading FP8 checkpoints with static weight scale and activation scale.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        """
+        Dynamic class composition pattern.
+
+        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
+        at runtime while avoiding circular import issues.
+        """
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
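For readers unfamiliar with this pattern, here is a minimal, self-contained sketch of the same idea: the base class is only imported (or, here, only referenced) when an instance is built, and the subclass is assembled at runtime with type(). The class names are illustrative, not SGLang APIs, and the sketch also skips the copied __weakref__ descriptor for safety.

    class LateBase:
        """Stand-in for a base class that would normally cause a circular import."""
        def greet(self):
            return "hello from base"

    class Composed:
        def __new__(cls, *args, **kwargs):
            # Build a subclass of LateBase on the fly, copying this class's dict,
            # so the instance behaves as if Composed had inherited from LateBase.
            new_cls = type(
                cls.__name__,
                (LateBase,),
                {k: v for k, v in cls.__dict__.items()
                 if k not in ("__dict__", "__weakref__")},
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj

        def __init__(self, tag):
            self.tag = tag

    obj = Composed("fp8")
    print(isinstance(obj, LateBase), obj.greet(), obj.tag)  # True hello from base fp8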
+    def __init__(self, quant_config: ModelOptFp8Config):
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        w13_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        w2_weight = ModelWeightParameter(
+            data=torch.empty(
+                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
+            ),
+            input_dim=2,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They will be combined to a single scale after weight loading.
+            w13_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts, 2),
+                    torch.finfo(torch.float32).min,
+                    dtype=torch.float32,
+                ),
+                weight_loader=weight_loader,
+            )
+            w2_weight_scale = PerTensorScaleParameter(
+                data=torch.full(
+                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
+                ),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+            # Set weight loader attributes for scales
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+            )
+
+            # INPUT SCALES - Per-tensor scaling for ModelOpt
+            w13_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            w2_input_scale = PerTensorScaleParameter(
+                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            layer.register_parameter("w2_input_scale", w2_input_scale)
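A quick sanity sketch of the buffers this method registers for one MoE layer. The sizes are made up; only the shapes and dtypes mirror the code above.

    import torch

    num_experts, hidden_size, intermediate_size = 4, 8, 16

    w13_weight = torch.empty(num_experts, 2 * intermediate_size, hidden_size,
                             dtype=torch.float8_e4m3fn)   # fused w1/w3, FP8
    w2_weight = torch.empty(num_experts, hidden_size, intermediate_size,
                            dtype=torch.float8_e4m3fn)    # down projection, FP8
    w13_weight_scale = torch.full((num_experts, 2), torch.finfo(torch.float32).min)
    w2_weight_scale = torch.full((num_experts,), torch.finfo(torch.float32).min)
    w13_input_scale = torch.full((num_experts,), 1.0)

    print(w13_weight.shape, w13_weight_scale.shape)  # (4, 32, 8) and (4, 2)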
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Process FP8 MoE weights after loading from serialized checkpoint.
+
+        Only supports pre-quantized checkpoints with FP8 weights and scales.
+        """
+        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
+        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
+
+        # Handle scale parameters
+        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
+            # Fp8 moe kernel needs single weight scale for w13 per expert.
+            # We take the max of the w1 and w3 scales then dequant and requant each expert.
+            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
+                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
+
+                # Get the maximum scale across w1 and w3 for each expert
+                max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+
+                # Requantize each expert's weights using the combined scale
+                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
+                # where the first intermediate_size rows are w1, the next are w3
+                intermediate_size = layer.w13_weight.shape[1] // 2
+                for expert_id in range(layer.w13_weight.shape[0]):
+                    start = 0
+                    for shard_id in range(2):  # w1 and w3
+                        # Dequantize using the original scale for this shard
+                        dq_weight = per_tensor_dequantize(
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            layer.w13_weight_scale[expert_id][shard_id],
+                        )
+                        # Requantize using the combined max scale
+                        (
+                            layer.w13_weight[expert_id][
+                                start : start + intermediate_size, :
+                            ],
+                            _,
+                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                        start += intermediate_size
+
+                # Update the scale parameter to be per-expert instead of per-shard
+                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
+            else:
+                layer.w13_weight_scale = Parameter(
+                    layer.w13_weight_scale.data, requires_grad=False
+                )
+
+        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
+            layer.w2_weight_scale = Parameter(
+                layer.w2_weight_scale.data, requires_grad=False
+            )
+        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
+            layer.w13_input_scale = Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
+            layer.w2_input_scale = Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
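The dequant-then-requant step above is just a change of per-tensor scale. A tiny numeric illustration of the same idea in plain float arithmetic; per_tensor_dequantize and scaled_fp8_quant are the real helpers, this only mimics their math and involves no FP8 types:

    import torch

    # Two shards (w1, w3) quantized with different per-tensor scales.
    w1_scale, w3_scale = torch.tensor(0.02), torch.tensor(0.05)
    w1_q = torch.tensor([100.0, -50.0])   # stand-ins for quantized codes
    w3_q = torch.tensor([10.0, 80.0])

    combined = torch.max(w1_scale, w3_scale)      # the per-expert max scale

    # Dequantize with the original scale, requantize with the combined one.
    w1_requant = (w1_q * w1_scale) / combined
    w3_requant = (w3_q * w3_scale) / combined

    # Values dequantized with the single combined scale match the originals.
    assert torch.allclose(w1_requant * combined, w1_q * w1_scale)
    assert torch.allclose(w3_requant * combined, w3_q * w3_scale)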
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        num_fused_shared_experts: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+        from sglang.srt.layers.moe.topk import select_experts
+
+        # Expert selection
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            num_fused_shared_experts=num_fused_shared_experts,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace,
+            activation=activation,
+            use_fp8_w8a8=True,
+            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            no_combine=no_combine,
+        )
+
+
 class ModelOptFp4Config(QuantizationConfig):
     """Config class for FP4."""
...
python/sglang/srt/models/mllama4.py
+import json as json_lib
+import logging
+import os
 from collections.abc import Iterable
 from typing import List, Optional, Set, Tuple
@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cpu

 _is_cpu = is_cpu()

+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+

 class Llama4ForConditionalGeneration(nn.Module):
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
         self.config = config
         self.quant_config = quant_config
-        self.vision_model = Llama4VisionModel(config.vision_config)
-        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
+        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
+        self.has_vision = self._has_vision_weights(config)
+        if not self.has_vision:
+            logger.warning(
+                "No vision weights found in checkpoint. Model will run in text-only mode. "
+                "Multimodal capabilities (image processing) will be unavailable."
+            )
+
+        if self.has_vision:
+            self.vision_model = Llama4VisionModel(config.vision_config)
+            self.multi_modal_projector = Llama4MultiModalProjector(config)
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None

         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM

         self.language_model = Llama4ForCausalLM(
-            config.text_config,
+            config.text_config if hasattr(config, "text_config") else config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )

-        self.logits_processor = LogitsProcessor(config.text_config)
+        self.logits_processor = LogitsProcessor(
+            config.text_config if hasattr(config, "text_config") else config
+        )
+    def _has_vision_weights(self, config) -> bool:
+        """Check if the model has vision components by examining the checkpoint."""
+        model_path = getattr(config, "_name_or_path", None)
+        if not model_path:
+            return False
+
+        # Check if this is a local path first
+        if os.path.isdir(model_path):
+            index_file = os.path.join(model_path, "model.safetensors.index.json")
+            if os.path.exists(index_file):
+                return self._check_vision_weights_in_index(index_file)
+
+        # For HuggingFace models, we need to check the actual checkpoint
+        # The config might say it's multimodal, but the checkpoint might be text-only
+        try:
+            # Try to access the HuggingFace cache directory
+            from huggingface_hub import try_to_load_from_cache
+
+            # Check if index file exists in cache
+            index_file_path = try_to_load_from_cache(
+                repo_id=model_path,
+                filename="model.safetensors.index.json",
+                cache_dir=None,
+            )
+            if index_file_path and os.path.exists(index_file_path):
+                return self._check_vision_weights_in_index(index_file_path)
+        except Exception:
+            # If we can't access the cache, fall back to config-based detection
+            pass
+
+        # Fallback, assume text-only
+        return False
+
+    def _check_vision_weights_in_index(self, index_file: str) -> bool:
+        """Check if the model.safetensors.index.json contains vision weights."""
+        try:
+            with open(index_file, "r") as f:
+                index_data = json_lib.load(f)
+
+            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
+            weight_names = index_data.get("weight_map", {}).keys()
+
+            return any(
+                pattern in weight_name
+                for weight_name in weight_names
+                for pattern in vision_patterns
+            )
+        except (OSError, json_lib.JSONDecodeError, KeyError):
+            return False
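As an illustration of what this check sees, here is a throwaway safetensors index with one vision entry; the weight names and file names are invented:

    import json, tempfile

    index = {
        "weight_map": {
            "language_model.model.layers.0.self_attn.q_proj.weight": "model-00001.safetensors",
            "vision_model.patch_embedding.weight": "model-00002.safetensors",
        }
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(index, f)
        path = f.name

    vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
    with open(path) as f:
        weight_names = json.load(f).get("weight_map", {}).keys()

    has_vision = any(p in name for name in weight_names for p in vision_patterns)
    print(has_vision)  # True; a text-only ModelOpt FP8 export would print False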
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
         self,
         items: List[MultimodalDataItem],
     ) -> torch.Tensor:
+        # For text-only models, return None or raise an error
+        if not self.has_vision or self.vision_model is None:
+            raise ValueError("Vision model not available for text-only checkpoint")
+
         pixel_values = (
             torch.concat([item.pixel_values for item in items])
             .to(next(self.vision_model.parameters()).device)
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
         **kwargs: object,
     ) -> torch.Tensor:
+        # For text-only models, pass None for image_data_embedding_func
+        image_embedding_func = self.get_image_feature if self.has_vision else None
+
         hs = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-            image_data_embedding_func=self.get_image_feature,
+            image_data_embedding_func=image_embedding_func,
             positions=positions,
         )
@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
         return name, loaded_weight

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
         ]

         params_dict = dict(self.named_parameters())

-        num_experts = self.config.text_config.num_local_experts
+        num_experts = (
+            self.config.text_config.num_local_experts
+            if hasattr(self.config, "text_config")
+            else self.config.num_local_experts
+        )

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
@@ -150,81 +233,279 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         for name, loaded_weight in weights:
-            if not "vision" in name:
+            if self._should_skip_weight(name):
+                continue
+
+            name = self._transform_weight_name(name)
+
+            if "vision" not in name:
                 name, loaded_weight = self.permute_qk_weight_for_rotary(
                     name, loaded_weight
                 )

-            for param_name, weight_name, shard_id in stacked_params_mapping:
-                if weight_name not in name:
-                    continue
-                if "vision" in name:
-                    continue
-                name = name.replace(weight_name, param_name)
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
-                break
+            if self._handle_scale_remapping(name, params_dict):
+                continue
+
+            if self._handle_stacked_params(
+                name, loaded_weight, stacked_params_mapping, params_dict
+            ):
+                continue
+
+            if self._handle_expert_weights(
+                name, loaded_weight, expert_params_mapping, params_dict, num_experts
+            ):
+                continue
+
+            self._handle_default_weight(name, loaded_weight, params_dict)
+    def _should_skip_weight(self, name: str) -> bool:
+        """Check if we should skip loading this weight."""
+        return "vision" in name and not self.has_vision
+
+    def _transform_weight_name(self, name: str) -> str:
+        """Transform weight name by adding language_model prefix if needed."""
+        if (
+            not name.startswith("language_model.")
+            and "vision" not in name
+            and "multi_modal_projector" not in name
+        ):
+            return f"language_model.{name}"
+        return name
+
+    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
+        """Handle scale parameter remapping. Returns True if handled."""
+        if "scale" in name and "expert" not in name:
+            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
+            return remapped_name is None
+        return False
+
+    def _handle_stacked_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        stacked_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle stacked parameter loading. Returns True if handled."""
+        for param_name, weight_name, shard_id in stacked_params_mapping:
+            if weight_name in name and "vision" not in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(param, loaded_weight, shard_id)
+                return True
+        return False
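The stacked-parameter mapping simply rewrites checkpoint names onto the fused parameters and tags the shard. A standalone illustration with one hypothetical checkpoint name; only the "q" row is visible in the truncated diff above, the k/v rows here are assumed to follow the same shape:

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
        (".self_attn.qkv_proj", ".self_attn.k_proj", "k"),
        (".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
    ]

    name = "language_model.model.layers.0.self_attn.k_proj.weight"
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            print(name.replace(weight_name, param_name), shard_id)
            # -> language_model.model.layers.0.self_attn.qkv_proj.weight k
            break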
+    def _handle_expert_weights(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle expert weight loading for MoE (Mixture of Experts) layers.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: Mapping of parameter names to expert configurations
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts in the MoE layer
+
+        Returns:
+            bool: True if the parameter was handled (is an expert parameter), False otherwise
+        """
+        if ".experts" not in name:
+            return False
+
+        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
+            return self._handle_other_expert_params(
+                name, loaded_weight, expert_params_mapping, params_dict
+            )
+
+        if "scale" in name:
+            return self._handle_expert_scale_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+        else:
+            return self._handle_expert_weight_params(
+                name, loaded_weight, params_dict, num_experts
+            )
+
+    def _handle_other_expert_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        expert_params_mapping: list,
+        params_dict: dict,
+    ) -> bool:
+        """Handle expert parameters that are not gate_up_proj or down_proj weights.
+
+        Args:
+            name: Parameter name from the checkpoint
+            loaded_weight: The weight tensor to be loaded
+            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
+            params_dict: Dictionary of model parameters
+
+        Returns:
+            bool: True if parameter was found and handled, False otherwise
+        """
+        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
+            if weight_name in name:
+                transformed_name = name.replace(weight_name, param_name)
+                param = params_dict[transformed_name]
+                param.weight_loader(
+                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
+                )
+                return True
+        return False
+
+    def _transform_expert_name(
+        self, name: str, is_weight: bool = False
+    ) -> Tuple[str, str, List[str]]:
+        """Transform expert parameter name and get shard information.
+
+        Args:
+            name: The original parameter name
+            is_weight: Whether this is a weight parameter (adds _weight suffix)
+
+        Returns:
+            Tuple of (transformed_name, shard_id, shard_id_list)
+        """
+        suffix = "_weight" if is_weight else ""
+
+        if ".gate_up_proj" in name:
+            transformed_name = name.replace(
+                ".experts.gate_up_proj", f".experts.w13{suffix}"
+            )
+            shard_id = "w13"
+            shard_id_list = ["w1", "w3"]
+        else:  # down_proj
+            transformed_name = name.replace(".experts.down_proj", f".experts.w2{suffix}")
+            shard_id = "w2"
+            shard_id_list = ["w2"]
+
+        return transformed_name, shard_id, shard_id_list
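Concretely, the transformation maps the checkpoint's fused expert tensors onto the FusedMoE parameter names. The layer path in this example is hypothetical:

    name = "language_model.model.layers.3.feed_forward.experts.gate_up_proj"
    transformed = name.replace(".experts.gate_up_proj", ".experts.w13_weight")
    print(transformed)
    # language_model.model.layers.3.feed_forward.experts.w13_weight
    # shard_id_list = ["w1", "w3"]: the fused tensor is split in half along its
    # last dim, and each half is loaded as the w1 / w3 shard of w13_weight.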
+    def _handle_expert_scale_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle quantization scale parameters for expert weights.
+
+        Args:
+            name: Parameter name containing scale information
+            loaded_weight: Scale tensor to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for broadcast operations
+
+        Returns:
+            bool: True (always handles scale parameters)
+        """
+        import re
+
+        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
+        expert_match = re.search(r"experts\.(\d+)\.", name)
+
+        # Transform name
+        transformed_name, _, _ = self._transform_expert_name(name)
+
+        if transformed_name not in params_dict:
+            return True
+
+        param = params_dict[transformed_name]
+
+        # Handle scale parameters
+        if expert_match:
+            # If we have a specific expert ID, only load for that expert
+            expert_id = int(expert_match.group(1))
+            # For scale parameters, we can directly set the value
+            param.data[expert_id] = loaded_weight
+        else:
+            # No expert ID found - this is a single scale for all experts
+            # Load the same scale for all experts
+            for expert_id in range(num_experts):
+                param.data[expert_id] = loaded_weight
+
+        return True
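The regex distinguishes a per-expert scale from a single scale shared across experts. Both names below are illustrative:

    import re

    pattern = r"experts\.(\d+)\."
    per_expert = "model.layers.0.feed_forward.experts.7.down_proj.weight_scale"
    shared = "model.layers.0.feed_forward.experts.down_proj_weight_scale"

    m = re.search(pattern, per_expert)
    print(int(m.group(1)))              # 7 -> written into param.data[7] only
    print(re.search(pattern, shared))   # None -> broadcast to every expert slot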
+    def _handle_expert_weight_params(
+        self,
+        name: str,
+        loaded_weight: torch.Tensor,
+        params_dict: dict,
+        num_experts: int,
+    ) -> bool:
+        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).
+
+        Args:
+            name: Parameter name (should contain gate_up_proj or down_proj)
+            loaded_weight: Weight tensor(s) to be loaded
+            params_dict: Dictionary of model parameters
+            num_experts: Total number of experts for tensor distribution
+
+        Returns:
+            bool: True (always handles weight parameters)
+        """
+        # Transform name and get shard info
+        transformed_name, _, shard_id_list = self._transform_expert_name(
+            name, is_weight=True
+        )
+
+        if ".gate_up_proj" in name:
+            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
+        else:  # down_proj
+            loaded_weight_list = [loaded_weight]
+
+        for param_name, weight_chunk, shard_id in zip(
+            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
+        ):
+            if param_name not in params_dict:
+                continue
+
+            param = params_dict[param_name]
+            weight_loader = param.weight_loader
+
+            # Handle the case where loaded_weight might be a single tensor for all experts
+            if weight_chunk.dim() == 2:
+                # Single tensor case - load for all experts
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk.T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+            else:
+                # Multiple experts case - load each expert's weights
+                for expert_id in range(num_experts):
+                    weight_loader(
+                        param,
+                        weight_chunk[expert_id].T,
+                        param_name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+
+        return True
-            else:
-                if ".experts" in name:
-                    # NOTE: llama4 fp8 has different weight format for experts
-                    if (
-                        "experts.gate_up_proj" not in name
-                        and "experts.down_proj" not in name
-                    ):
-                        for mapping in expert_params_mapping:
-                            param_name, weight_name, expert_id, shard_id = mapping
-                            if weight_name not in name:
-                                continue
-                            name = name.replace(weight_name, param_name)
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            weight_loader(
-                                param,
-                                loaded_weight,
-                                name,
-                                shard_id=shard_id,
-                                expert_id=expert_id,
-                            )
-                            break
-                    else:
-                        if ".gate_up_proj" in name:
-                            name_list = [
-                                name.replace(
-                                    ".experts.gate_up_proj", ".experts.w13_weight"
-                                )
-                            ] * 2
-                            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
-                            shard_id_list = ["w1", "w3"]
-                        else:
-                            name_list = [
-                                name.replace(".experts.down_proj", ".experts.w2_weight")
-                            ]
-                            shard_id_list = ["w2"]
-                            loaded_weight_list = [loaded_weight]
-                        for name, loaded_weight, shard_id in zip(
-                            name_list, loaded_weight_list, shard_id_list
-                        ):
-                            param = params_dict[name]
-                            weight_loader = param.weight_loader
-                            for expert_id in range(num_experts):
-                                weight_loader(
-                                    param,
-                                    loaded_weight[expert_id].T,
-                                    name,
-                                    shard_id=shard_id,
-                                    expert_id=expert_id,
-                                )
-                else:
-                    # Skip loading extra bias for GPTQ models.
-                    if name.endswith(".bias") and name not in params_dict:
-                        continue
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
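The fused gate_up_proj tensor stores w1 and w3 side by side on its last dimension; chunk(2, dim=-1) separates them, and the .T matches the (output, input) layout the FusedMoE loader expects. A toy illustration with made-up sizes:

    import torch

    num_experts, hidden, intermediate = 2, 4, 3
    # Checkpoint layout: experts stacked on dim 0, w1 and w3 concatenated on the last dim.
    gate_up = torch.randn(num_experts, hidden, 2 * intermediate)

    w1_part, w3_part = gate_up.chunk(2, dim=-1)   # each (2, 4, 3)
    w1_expert0 = w1_part[0].T                     # (3, 4): (output, input) layout
    print(w1_part.shape, w1_expert0.shape)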
+    def _handle_default_weight(
+        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
+    ):
+        """Handle default weight loading."""
+        # Skip loading extra bias for GPTQ models
+        if name.endswith(".bias") and name not in params_dict:
+            return
+
+        param = params_dict[name]
+        weight_loader = getattr(param, "weight_loader", default_weight_loader)
+        weight_loader(param, loaded_weight)
     def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
         if hasattr(self.language_model, "set_eagle3_layers_to_capture"):
...