sglang / Commits / 659907e3

Unverified commit 659907e3, authored Jul 08, 2025 by Zhiyu, committed by GitHub on Jul 08, 2025

Enable ModelOpt Llama4 fp8 checkpoint deployment in SGLang (#7129)
parent cb9d91ea

Showing 3 changed files with 643 additions and 81 deletions (+643 / -81)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +39  -1
python/sglang/srt/layers/quantization/modelopt_quant.py  +244 -1
python/sglang/srt/models/mllama4.py                      +360 -79
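For orientation, this commit lets SGLang deploy a Llama4 checkpoint that was pre-quantized to FP8 with NVIDIA TensorRT Model Optimizer (ModelOpt). A minimal usage sketch via the offline engine API follows; the checkpoint path and tensor-parallel size are placeholders, and passing quantization="modelopt" explicitly is an assumption (the setting may also be picked up from the checkpoint's own quantization config):

import sglang as sgl

if __name__ == "__main__":
    # Placeholder path to a ModelOpt FP8 Llama4 checkpoint; adjust tp_size to your GPUs.
    llm = sgl.Engine(
        model_path="/path/to/llama4-modelopt-fp8",
        quantization="modelopt",
        tp_size=8,
    )
    print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))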
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (view file @ 659907e3)
...
...
@@ -649,6 +649,27 @@ class FusedMoE(torch.nn.Module):
        loaded_weight: torch.tensor,
        tp_rank: int,
    ):
        """Load w2 weights for down projection.

        Args:
            expert_data: The expert data tensor to load into
            shard_dim: The dimension to shard along
            shard_id: The shard ID (must be "w2")
            loaded_weight: The weight tensor to load from
            tp_rank: The tensor parallel rank
        """
        if not isinstance(expert_data, torch.Tensor) or not isinstance(
            loaded_weight, torch.Tensor
        ):
            raise ValueError("expert_data and loaded_weight must be torch.Tensor")
        if expert_data.dim() != 2 or loaded_weight.dim() != 2:
            raise ValueError(
                f"Expected 2D tensors, got expert_data shape {expert_data.shape} "
                f"and loaded_weight shape {loaded_weight.shape}"
            )
        if shard_id != "w2":
            raise ValueError(f"shard_id must be 'w2', got {shard_id}")

        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
...
...
@@ -669,6 +690,10 @@ class FusedMoE(torch.nn.Module):
        if not self.use_presharded_weights:
            if self.use_triton_kernels:
                loaded_weight = loaded_weight.transpose(-2, -1)
            if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
                raise ValueError(
                    f"Shard size {shard_size} at rank {tp_rank} exceeds "
                    f"loaded_weight dimension {loaded_weight.shape[shard_dim]}"
                )
            loaded_weight = loaded_weight.narrow(
                shard_dim, shard_size * tp_rank, shard_size
            )
...
...
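The bounds check added above guards the narrow() call that slices each tensor-parallel rank's shard out of the full checkpoint tensor. A standalone sketch of that row-parallel slicing, with toy sizes and a helper name invented here for illustration:

import torch

def shard_row_parallel(weight, shard_dim, tp_rank, tp_size):
    # Each rank owns a contiguous slice of size dim / tp_size along shard_dim.
    shard_size = weight.shape[shard_dim] // tp_size
    if shard_size * tp_rank + shard_size > weight.shape[shard_dim]:
        raise ValueError(
            f"Shard size {shard_size} at rank {tp_rank} exceeds dimension "
            f"{weight.shape[shard_dim]}"
        )
    return weight.narrow(shard_dim, shard_size * tp_rank, shard_size)

# Toy example: a (4, 8) down_proj weight sharded on the input dim across 4 ranks.
w = torch.arange(32, dtype=torch.float32).reshape(4, 8)
print(shard_row_parallel(w, shard_dim=1, tp_rank=2, tp_size=4).shape)  # torch.Size([4, 2])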
@@ -795,8 +820,21 @@ class FusedMoE(torch.nn.Module):
                tp_rank=tp_rank,
            )
            return

        if "ModelOpt" in self.quant_method.__class__.__name__:
            # Determine per-tensor weight scale patterns based on variant
            is_fp4_variant = (
                "ModelOptNvFp4FusedMoEMethod" in self.quant_method.__class__.__name__
            )
            # FP4 uses "weight_scale_2" for per-tensor, FP8 uses "weight_scale" for per-tensor
            per_tensor_conditions = (
                "weight_scale_2" in weight_name
                if is_fp4_variant
                else "weight_scale" in weight_name
            ) or "input_scale" in weight_name

            if per_tensor_conditions:
                self._load_per_tensor_weight_scale(
                    shard_id=shard_id,
                    param=param,
...
...
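The per_tensor_conditions expression above keys off the scale naming convention: NVFP4 checkpoints mark their per-tensor scale as "weight_scale_2", FP8 checkpoints use "weight_scale", and "input_scale" is per-tensor in both. A small self-contained restatement of that predicate (function name chosen here for illustration):

def is_per_tensor_scale(weight_name: str, is_fp4_variant: bool) -> bool:
    # FP4 uses "weight_scale_2" for per-tensor scales, FP8 uses "weight_scale";
    # "input_scale" is per-tensor for both variants.
    return (
        "weight_scale_2" in weight_name
        if is_fp4_variant
        else "weight_scale" in weight_name
    ) or "input_scale" in weight_name

assert is_per_tensor_scale("experts.w13_weight_scale", is_fp4_variant=False)
assert not is_per_tensor_scale("experts.w13_weight_scale", is_fp4_variant=True)
assert is_per_tensor_scale("experts.w13_weight_scale_2", is_fp4_variant=True)
assert is_per_tensor_scale("experts.w13_input_scale", is_fp4_variant=True)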
python/sglang/srt/layers/quantization/modelopt_quant.py (view file @ 659907e3)
...
...
@@ -26,6 +26,7 @@ from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt.layers.quantization.utils import (
    convert_to_channelwise,
    is_layer_skipped,
    per_tensor_dequantize,
    requantize_with_max_scale,
)
from sglang.srt.layers.radix_attention import RadixAttention
...
...
@@ -110,7 +111,12 @@ class ModelOptFp8Config(QuantizationConfig):
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if self.exclude_modules and any(
            module in prefix
            or (
                prefix.startswith("language_model.")
                and module in prefix.removeprefix("language_model.")
            )
            for module in self.exclude_modules
        ):
            return None
...
...
@@ -119,6 +125,12 @@ class ModelOptFp8Config(QuantizationConfig):
        if self.kv_cache_quant_method and isinstance(layer, RadixAttention):
            return ModelOptFp8KVCacheMethod(self)

        # Add MoE support
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, FusedMoE):
            return ModelOptFp8MoEMethod(self)

        return None

    def get_scaled_act_names(self) -> List[str]:
...
...
@@ -234,6 +246,237 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
        super().__init__(quant_config)


class ModelOptFp8MoEMethod:
    """MoE method for ModelOpt FP8.

    Supports loading FP8 checkpoints with static weight scale and activation scale.

    Args:
        quant_config: The ModelOpt quantization config.
    """

    def __new__(cls, *args, **kwargs):
        """
        Dynamic class composition pattern.

        This allows us to effectively "inject" FusedMoEMethodBase as a parent class
        at runtime while avoiding circular import issues.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase

        if not hasattr(cls, "_initialized"):
            original_init = cls.__init__
            new_cls = type(
                cls.__name__,
                (FusedMoEMethodBase,),
                {
                    "__init__": original_init,
                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
                },
            )
            obj = super(new_cls, new_cls).__new__(new_cls)
            obj.__init__(*args, **kwargs)
            return obj
        return super().__new__(cls)
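The __new__ above injects FusedMoEMethodBase as a base class at runtime to sidestep a circular import between the quantization and fused-MoE modules. The pattern in isolation looks like the sketch below (toy classes, names invented for illustration; the _initialized guard is omitted):

class LateBase:
    """Stands in for FusedMoEMethodBase, which the real code imports lazily."""

    def hello(self) -> str:
        return "from LateBase"


class Method:
    def __new__(cls, *args, **kwargs):
        # Rebuild the class with LateBase injected as a parent, then instantiate that.
        new_cls = type(
            cls.__name__,
            (LateBase,),
            {k: v for k, v in cls.__dict__.items() if k not in ("__dict__", "__weakref__")},
        )
        obj = super(new_cls, new_cls).__new__(new_cls)
        # __init__ must be called by hand: the returned object is not an instance
        # of `cls`, so Python will not invoke it automatically.
        obj.__init__(*args, **kwargs)
        return obj

    def __init__(self, tag: str):
        self.tag = tag


m = Method("fp8")
print(isinstance(m, LateBase), m.hello(), m.tag)  # True from LateBase fp8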
    def __init__(self, quant_config: ModelOptFp8Config):
        self.quant_config = quant_config
        self.cutlass_fp8_supported = cutlass_fp8_supported()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        # Use FP8 dtype if checkpoint is serialized, otherwise use the default dtype
        weight_dtype = (
            torch.float8_e4m3fn
            if self.quant_config.is_checkpoint_fp8_serialized
            else params_dtype
        )
        weight_loader = extra_weight_attrs.get("weight_loader")

        w13_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, 2 * intermediate_size, hidden_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w13_weight", w13_weight)

        w2_weight = ModelWeightParameter(
            data=torch.empty(
                num_experts, hidden_size, intermediate_size, dtype=weight_dtype
            ),
            input_dim=2,
            output_dim=1,
            weight_loader=weight_loader,
        )
        layer.register_parameter("w2_weight", w2_weight)

        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALES - Per-tensor scaling for ModelOpt
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts, 2),
                    torch.finfo(torch.float32).min,
                    dtype=torch.float32,
                ),
                weight_loader=weight_loader,
            )
            w2_weight_scale = PerTensorScaleParameter(
                data=torch.full(
                    (num_experts,), torch.finfo(torch.float32).min, dtype=torch.float32
                ),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

            # Set weight loader attributes for scales
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )

            # INPUT SCALES - Per-tensor scaling for ModelOpt
            w13_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            w2_input_scale = PerTensorScaleParameter(
                data=torch.full((num_experts,), 1.0, dtype=torch.float32),
                weight_loader=weight_loader,
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            layer.register_parameter("w2_input_scale", w2_input_scale)
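For a concrete sense of what create_weights registers for an FP8-serialized checkpoint, here is a sketch of the resulting parameter shapes for a toy configuration (sizes invented for illustration):

num_experts, hidden_size, intermediate_size = 4, 16, 32  # toy sizes

shapes = {
    "w13_weight": (num_experts, 2 * intermediate_size, hidden_size),  # float8_e4m3fn
    "w2_weight": (num_experts, hidden_size, intermediate_size),       # float8_e4m3fn
    "w13_weight_scale": (num_experts, 2),  # one scale each for w1 and w3, merged later
    "w2_weight_scale": (num_experts,),
    "w13_input_scale": (num_experts,),     # reduced to a single max after loading
    "w2_input_scale": (num_experts,),
}
for name, shape in shapes.items():
    print(f"{name:18s} {shape}")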
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Process FP8 MoE weights after loading from serialized checkpoint.

        Only supports pre-quantized checkpoints with FP8 weights and scales.
        """
        layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False)
        layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)

        # Handle scale parameters
        if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None:
            # Fp8 moe kernel needs single weight scale for w13 per expert.
            # We take the max of the w1 and w3 scales then dequant and requant each expert.
            if layer.w13_weight_scale.dim() == 2:  # Shape: (num_experts, 2)
                from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

                # Get the maximum scale across w1 and w3 for each expert
                max_w13_scales = layer.w13_weight_scale.max(dim=1).values

                # Requantize each expert's weights using the combined scale
                # w13_weight has shape (num_experts, 2 * intermediate_size, hidden_size)
                # where the first intermediate_size rows are w1, the next are w3
                intermediate_size = layer.w13_weight.shape[1] // 2
                for expert_id in range(layer.w13_weight.shape[0]):
                    start = 0
                    for shard_id in range(2):  # w1 and w3
                        # Dequantize using the original scale for this shard
                        dq_weight = per_tensor_dequantize(
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            layer.w13_weight_scale[expert_id][shard_id],
                        )
                        # Requantize using the combined max scale
                        (
                            layer.w13_weight[expert_id][
                                start : start + intermediate_size, :
                            ],
                            _,
                        ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                        start += intermediate_size

                # Update the scale parameter to be per-expert instead of per-shard
                layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False)
            else:
                layer.w13_weight_scale = Parameter(
                    layer.w13_weight_scale.data, requires_grad=False
                )

        if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None:
            layer.w2_weight_scale = Parameter(
                layer.w2_weight_scale.data, requires_grad=False
            )
        if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None:
            layer.w13_input_scale = Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
        if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None:
            layer.w2_input_scale = Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )
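The requantization loop above merges the separate w1 and w3 scales into one per-expert scale: each shard is dequantized with its original scale, then requantized with the shared maximum. A standalone numeric sketch of the same idea, using a simulated per-tensor quantizer instead of the FP8 kernels (448 is the largest magnitude representable in FP8 e4m3):

import torch

def fake_quant(x: torch.Tensor, scale: torch.Tensor, max_repr: float = 448.0) -> torch.Tensor:
    # Simulate per-tensor quantization: store x / scale clamped to the representable range.
    return torch.clamp(x / scale, -max_repr, max_repr)

w1, w3 = torch.randn(8, 8) * 0.5, torch.randn(8, 8) * 2.0
s1, s3 = w1.abs().max() / 448.0, w3.abs().max() / 448.0
q1, q3 = fake_quant(w1, s1), fake_quant(w3, s3)

# Merge: dequantize each shard with its own scale, requantize with the shared max scale.
s_max = torch.maximum(s1, s3)
q1_new, q3_new = fake_quant(q1 * s1, s_max), fake_quant(q3 * s3, s_max)

# Both shards can now be dequantized with the single merged scale.
print(torch.allclose(q1_new * s_max, w1, atol=1e-3),
      torch.allclose(q3_new * s_max, w3, atol=1e-3))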
    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        correction_bias: Optional[torch.Tensor] = None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts

        # Expert selection
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=inplace,
            activation=activation,
            use_fp8_w8a8=True,
            per_channel_quant=False,  # ModelOpt uses per-tensor quantization
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            no_combine=no_combine,
        )


class ModelOptFp4Config(QuantizationConfig):
    """Config class for FP4."""
...
...
python/sglang/srt/models/mllama4.py (view file @ 659907e3)
import json as json_lib
import logging
import os
from collections.abc import Iterable
from typing import List, Optional, Set, Tuple
...
...
@@ -19,6 +22,13 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, is_cpu

_is_cpu = is_cpu()

logger = logging.getLogger(__name__)


class Llama4ForConditionalGeneration(nn.Module):
...
...
@@ -37,19 +47,85 @@ class Llama4ForConditionalGeneration(nn.Module):
        self.config = config
        self.quant_config = quant_config

        # Check if this is a text-only model (modelopt fp8 llama4 has no vision components)
        self.has_vision = self._has_vision_weights(config)
        if not self.has_vision:
            logger.warning(
                "No vision weights found in checkpoint. Model will run in text-only mode. "
                "Multimodal capabilities (image processing) will be unavailable."
            )

        if self.has_vision:
            self.vision_model = Llama4VisionModel(config.vision_config)
            self.multi_modal_projector = Llama4MultiModalProjector(config)
        else:
            self.vision_model = None
            self.multi_modal_projector = None

        # Initialize the language model
        from sglang.srt.models.llama4 import Llama4ForCausalLM

        self.language_model = Llama4ForCausalLM(
            config.text_config if hasattr(config, "text_config") else config,
            quant_config=quant_config,
            prefix=add_prefix("language_model", prefix),
        )

        self.logits_processor = LogitsProcessor(
            config.text_config if hasattr(config, "text_config") else config
        )
    def _has_vision_weights(self, config) -> bool:
        """Check if the model has vision components by examining the checkpoint."""
        model_path = getattr(config, "_name_or_path", None)
        if not model_path:
            return False

        # Check if this is a local path first
        if os.path.isdir(model_path):
            index_file = os.path.join(model_path, "model.safetensors.index.json")
            if os.path.exists(index_file):
                return self._check_vision_weights_in_index(index_file)

        # For HuggingFace models, we need to check the actual checkpoint
        # The config might say it's multimodal, but the checkpoint might be text-only
        try:
            # Try to access the HuggingFace cache directory
            from huggingface_hub import try_to_load_from_cache

            # Check if index file exists in cache
            index_file_path = try_to_load_from_cache(
                repo_id=model_path,
                filename="model.safetensors.index.json",
                cache_dir=None,
            )
            if index_file_path and os.path.exists(index_file_path):
                return self._check_vision_weights_in_index(index_file_path)
        except Exception:
            # If we can't access the cache, fall back to config-based detection
            pass

        # Fallback, assume text-only
        return False
    def _check_vision_weights_in_index(self, index_file: str) -> bool:
        """Check if the model.safetensors.index.json contains vision weights."""
        try:
            with open(index_file, "r") as f:
                index_data = json_lib.load(f)

            vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
            weight_names = index_data.get("weight_map", {}).keys()

            return any(
                pattern in weight_name
                for weight_name in weight_names
                for pattern in vision_patterns
            )
        except (OSError, json_lib.JSONDecodeError, KeyError):
            return False
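The text-only detection above only needs the weight_map keys from model.safetensors.index.json. A self-contained sketch of the same check against an in-memory index (the weight names are fabricated for illustration):

import json

index_json = json.dumps({
    "weight_map": {
        "language_model.model.layers.0.self_attn.qkv_proj.weight": "model-00001.safetensors",
        "language_model.lm_head.weight": "model-00002.safetensors",
    }
})

vision_patterns = ["vision_model", "vision_tower", "multi_modal_projector"]
weight_names = json.loads(index_json).get("weight_map", {}).keys()

has_vision = any(p in name for name in weight_names for p in vision_patterns)
print(has_vision)  # False -> the checkpoint is treated as text-only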
    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
...
...
@@ -59,6 +135,10 @@ class Llama4ForConditionalGeneration(nn.Module):
        self,
        items: List[MultimodalDataItem],
    ) -> torch.Tensor:
        # For text-only models, return None or raise an error
        if not self.has_vision or self.vision_model is None:
            raise ValueError("Vision model not available for text-only checkpoint")

        pixel_values = (
            torch.concat([item.pixel_values for item in items])
            .to(next(self.vision_model.parameters()).device)
...
...
@@ -79,11 +159,14 @@ class Llama4ForConditionalGeneration(nn.Module):
        **kwargs: object,
    ) -> torch.Tensor:
        # For text-only models, pass None for image_data_embedding_func
        image_embedding_func = self.get_image_feature if self.has_vision else None

        hs = general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.language_model,
            image_data_embedding_func=image_embedding_func,
            positions=positions,
        )
...
...
@@ -124,7 +207,6 @@ class Llama4ForConditionalGeneration(nn.Module):
        return name, loaded_weight

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
...
...
@@ -137,11 +219,12 @@ class Llama4ForConditionalGeneration(nn.Module):
        ]

        params_dict = dict(self.named_parameters())

        num_experts = (
            self.config.text_config.num_local_experts
            if hasattr(self.config, "text_config")
            else self.config.num_local_experts
        )

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
...
...
@@ -150,80 +233,278 @@ class Llama4ForConditionalGeneration(nn.Module):
        )

        for name, loaded_weight in weights:
            if self._should_skip_weight(name):
                continue

            name = self._transform_weight_name(name)

            if "vision" not in name:
                name, loaded_weight = self.permute_qk_weight_for_rotary(
                    name, loaded_weight
                )

            if self._handle_scale_remapping(name, params_dict):
                continue

            if self._handle_stacked_params(
                name, loaded_weight, stacked_params_mapping, params_dict
            ):
                continue

            # NOTE: llama4 fp8 has different weight format for experts
            if self._handle_expert_weights(
                name, loaded_weight, expert_params_mapping, params_dict, num_experts
            ):
                continue

            self._handle_default_weight(name, loaded_weight, params_dict)
    def _should_skip_weight(self, name: str) -> bool:
        """Check if we should skip loading this weight."""
        return "vision" in name and not self.has_vision
    def _transform_weight_name(self, name: str) -> str:
        """Transform weight name by adding language_model prefix if needed."""
        if (
            not name.startswith("language_model.")
            and "vision" not in name
            and "multi_modal_projector" not in name
        ):
            return f"language_model.{name}"
        return name
    def _handle_scale_remapping(self, name: str, params_dict: dict) -> bool:
        """Handle scale parameter remapping. Returns True if handled."""
        if "scale" in name and "expert" not in name:
            remapped_name = maybe_remap_kv_scale_name(name, params_dict)
            return remapped_name is None
        return False
    def _handle_stacked_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        stacked_params_mapping: list,
        params_dict: dict,
    ) -> bool:
        """Handle stacked parameter loading. Returns True if handled."""
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name and "vision" not in name:
                transformed_name = name.replace(weight_name, param_name)
                param = params_dict[transformed_name]
                param.weight_loader(param, loaded_weight, shard_id)
                return True
        return False
    def _handle_expert_weights(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        expert_params_mapping: list,
        params_dict: dict,
        num_experts: int,
    ) -> bool:
        """Handle expert weight loading for MoE (Mixture of Experts) layers.

        Args:
            name: Parameter name from the checkpoint
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: Mapping of parameter names to expert configurations
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts in the MoE layer

        Returns:
            bool: True if the parameter was handled (is an expert parameter), False otherwise
        """
        if ".experts" not in name:
            return False

        if "experts.gate_up_proj" not in name and "experts.down_proj" not in name:
            return self._handle_other_expert_params(
                name, loaded_weight, expert_params_mapping, params_dict
            )

        if "scale" in name:
            return self._handle_expert_scale_params(
                name, loaded_weight, params_dict, num_experts
            )

        return self._handle_expert_weight_params(
            name, loaded_weight, params_dict, num_experts
        )
    def _handle_other_expert_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        expert_params_mapping: list,
        params_dict: dict,
    ) -> bool:
        """Handle expert parameters that are not gate_up_proj or down_proj weights.

        Args:
            name: Parameter name from the checkpoint
            loaded_weight: The weight tensor to be loaded
            expert_params_mapping: List of tuples mapping checkpoint names to model parameters
            params_dict: Dictionary of model parameters

        Returns:
            bool: True if parameter was found and handled, False otherwise
        """
        for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
            if weight_name in name:
                transformed_name = name.replace(weight_name, param_name)
                param = params_dict[transformed_name]
                param.weight_loader(
                    param, loaded_weight, name, shard_id=shard_id, expert_id=expert_id
                )
                return True
        return False
    def _transform_expert_name(
        self, name: str, is_weight: bool = False
    ) -> Tuple[str, str, List[str]]:
        """Transform expert parameter name and get shard information.

        Args:
            name: The original parameter name
            is_weight: Whether this is a weight parameter (adds _weight suffix)

        Returns:
            Tuple of (transformed_name, shard_id, shard_id_list)
        """
        suffix = "_weight" if is_weight else ""

        if ".gate_up_proj" in name:
            transformed_name = name.replace(
                ".experts.gate_up_proj", f".experts.w13{suffix}"
            )
            shard_id = "w13"
            shard_id_list = ["w1", "w3"]
        else:
            # down_proj
            transformed_name = name.replace(
                ".experts.down_proj", f".experts.w2{suffix}"
            )
            shard_id = "w2"
            shard_id_list = ["w2"]

        return transformed_name, shard_id, shard_id_list
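Putting the name handling together: a fused-expert checkpoint key passes through the language_model. prefixing and the gate_up_proj/down_proj renaming before lookup in params_dict. A small sketch of those two string transformations (the example key is illustrative, not taken from a real checkpoint):

def transform(name: str) -> str:
    # 1) Prefix with "language_model." unless already prefixed or a vision weight.
    if (
        not name.startswith("language_model.")
        and "vision" not in name
        and "multi_modal_projector" not in name
    ):
        name = f"language_model.{name}"
    # 2) Map fused-expert checkpoint names onto the FusedMoE parameter names.
    name = name.replace(".experts.gate_up_proj", ".experts.w13_weight")
    name = name.replace(".experts.down_proj", ".experts.w2_weight")
    return name

print(transform("model.layers.0.feed_forward.experts.gate_up_proj"))
# language_model.model.layers.0.feed_forward.experts.w13_weight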
    def _handle_expert_scale_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
    ) -> bool:
        """Handle quantization scale parameters for expert weights.

        Args:
            name: Parameter name containing scale information
            loaded_weight: Scale tensor to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for broadcast operations

        Returns:
            bool: True (always handles scale parameters)
        """
        import re

        # Check if this matches the expert parameter pattern: experts.{expert_id}.{param_name}
        expert_match = re.search(r"experts\.(\d+)\.", name)

        # Transform name
        transformed_name, _, _ = self._transform_expert_name(name)

        if transformed_name not in params_dict:
            return True

        param = params_dict[transformed_name]

        # Handle scale parameters
        if expert_match:
            # If we have a specific expert ID, only load for that expert
            expert_id = int(expert_match.group(1))
            # For scale parameters, we can directly set the value
            param.data[expert_id] = loaded_weight
        else:
            # No expert ID found - this is a single scale for all experts
            # Load the same scale for all experts
            for expert_id in range(num_experts):
                param.data[expert_id] = loaded_weight

        return True
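The scale handler above distinguishes per-expert scales (an experts.<id>. segment in the name) from a single scale shared by all experts, which it broadcasts. A standalone sketch of that branch (example names invented for illustration):

import re
import torch

num_experts = 4
scale_param = torch.zeros(num_experts)

def load_scale(name: str, value: float) -> None:
    match = re.search(r"experts\.(\d+)\.", name)
    if match:
        # Per-expert scale: write only the matching slot.
        scale_param[int(match.group(1))] = value
    else:
        # Shared scale: broadcast to every expert.
        scale_param[:] = value

load_scale("model.layers.0.feed_forward.experts.down_proj.input_scale", 0.5)
print(scale_param)  # tensor([0.5000, 0.5000, 0.5000, 0.5000])
load_scale("model.layers.0.feed_forward.experts.2.down_proj.weight_scale", 0.25)
print(scale_param)  # tensor([0.5000, 0.5000, 0.2500, 0.5000])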
    def _handle_expert_weight_params(
        self,
        name: str,
        loaded_weight: torch.Tensor,
        params_dict: dict,
        num_experts: int,
    ) -> bool:
        """Handle actual weight tensors for expert layers (gate_up_proj and down_proj).

        Args:
            name: Parameter name (should contain gate_up_proj or down_proj)
            loaded_weight: Weight tensor(s) to be loaded
            params_dict: Dictionary of model parameters
            num_experts: Total number of experts for tensor distribution

        Returns:
            bool: True (always handles weight parameters)
        """
        # Transform name and get shard info
        transformed_name, _, shard_id_list = self._transform_expert_name(
            name, is_weight=True
        )

        if ".gate_up_proj" in name:
            loaded_weight_list = loaded_weight.chunk(2, dim=-1)
        else:
            # down_proj
            loaded_weight_list = [loaded_weight]

        for param_name, weight_chunk, shard_id in zip(
            [transformed_name] * len(shard_id_list), loaded_weight_list, shard_id_list
        ):
            if param_name not in params_dict:
                continue

            param = params_dict[param_name]
            weight_loader = param.weight_loader

            # Handle the case where loaded_weight might be a single tensor for all experts
            if weight_chunk.dim() == 2:
                # Single tensor case - load for all experts
                for expert_id in range(num_experts):
                    weight_loader(
                        param,
                        weight_chunk.T,
                        param_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
            else:
                # Multiple experts case - load each expert's weights
                for expert_id in range(num_experts):
                    weight_loader(
                        param,
                        weight_chunk[expert_id].T,
                        param_name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )

        return True
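The weight handler above splits the fused gate_up_proj tensor into its w1 and w3 halves along the last dimension and transposes each expert slice before handing it to the FusedMoE loader. A toy-shaped sketch of that split (dimensions invented):

import torch

num_experts, hidden_size, intermediate_size = 2, 8, 4
# Checkpoint layout: (num_experts, hidden_size, 2 * intermediate_size)
gate_up = torch.randn(num_experts, hidden_size, 2 * intermediate_size)

w1_chunk, w3_chunk = gate_up.chunk(2, dim=-1)
print(w1_chunk.shape)       # torch.Size([2, 8, 4])
print(w1_chunk[0].T.shape)  # torch.Size([4, 8]) -> per-expert (intermediate_size, hidden_size)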
    def _handle_default_weight(
        self, name: str, loaded_weight: torch.Tensor, params_dict: dict
    ):
        """Handle default weight loading."""
        # Skip loading extra bias for GPTQ models
        if name.endswith(".bias") and name not in params_dict:
            return

        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
...
...