sglang · Commit 4060ed37 (unverified)

Refactoring GLM-4.5 and GLM-4.5V related implementations (#11800)

Authored by Yuxuan Zhang on Oct 24, 2025; committed by GitHub on Oct 24, 2025.
Parent: 2342605e

Changes: 4 changed files with 356 additions and 565 deletions (+356 / -565).
Files changed:

  python/sglang/srt/models/glm4_moe.py               +322  -354
  python/sglang/srt/models/glm4_moe_nextn.py           +4   -14
  python/sglang/srt/models/glm4v_moe.py                +29  -196
  python/sglang/srt/multimodal/processors/glm4v.py      +1    -1
python/sglang/srt/models/glm4_moe.py

  (This diff is collapsed and not shown.)
python/sglang/srt/models/glm4_moe_nextn.py

@@ -12,7 +12,8 @@
 # limitations under the License.
 # ==============================================================================
-"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
 import logging
 from typing import Iterable, Optional, Tuple
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import BumpAllocator, add_prefix
+from sglang.srt.utils import add_prefix

 logger = logging.getLogger(__name__)
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        zero_allocator = BumpAllocator(
-            buffer_size=2,
-            dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
-        )
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
         residual = None
         with get_global_expert_distribution_recorder().disable_this_region():
             hidden_states, residual = self.decoder(
-                positions, hidden_states, forward_batch, residual, zero_allocator
+                positions, hidden_states, forward_batch, residual
             )

         if not forward_batch.forward_mode.is_idle():
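The two hunks above drop the pre-allocated `zero_allocator` scratch buffer from the NextN forward path, so the decoder call no longer threads it through. For readers unfamiliar with the pattern, below is a minimal sketch of what a bump allocator does; `TinyBumpAllocator` is a hypothetical illustration, not sglang's `BumpAllocator` implementation.

```python
import torch

class TinyBumpAllocator:
    """Illustrative bump allocator: hands out slices of one pre-allocated buffer."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str = "cpu"):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        # Return the next `size` elements and advance the cursor; no per-call
        # device allocation happens after construction.
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out

alloc = TinyBumpAllocator(buffer_size=2, dtype=torch.float32)
scratch = alloc.allocate(1)  # zero-filled scratch tensor reused across layers
```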
@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):
 class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
         self.model = Glm4MoeModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
python/sglang/srt/models/glm4v_moe.py
@@ -6,13 +6,10 @@ import torch
 import torch.nn as nn
 from transformers.models.glm4v_moe.configuration_glm4v_moe import Glm4vMoeConfig

-from sglang.srt.distributed import (
-    get_moe_expert_parallel_world_size,
-    get_tensor_model_parallel_world_size,
-)
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
@@ -20,7 +17,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils import add_prefix, is_cuda
 from sglang.srt.utils.hf_transformers_utils import get_processor

 _is_cuda = is_cuda()
@@ -39,12 +36,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
     ) -> None:
         nn.Module.__init__(self)
         config.moe_layer_freq = 1
         self.config = config
         vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLM")
+        self.num_fused_shared_experts = (
+            0
+            if get_global_server_args().disable_shared_experts_fusion
@@ -77,38 +72,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

-    def determine_num_fused_shared_experts(
-        self, architecture: str = "Glm4MoeForCausalLM"
-    ):
-        self.num_fused_shared_experts = 0
-        if get_global_server_args().disable_shared_experts_fusion:
-            return
-
-        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
-        disable_reason = None
-        if (
-            not _is_cuda
-            or torch.cuda.get_device_capability("cuda") < (8, 0)
-            or self.config.architectures[0] != architecture
-            or self.config.n_shared_experts != 1
-        ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
-        elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
-
-        if disable_reason is not None:
-            get_global_server_args().disable_shared_experts_fusion = True
-            self.num_fused_shared_experts = 0
-            log_info_on_rank0(
-                logger,
-                f"{disable_reason} Shared experts fusion optimization is disabled.",
-            )
-            return
-
-        self.num_fused_shared_experts = self.config.n_shared_experts
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
         if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
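The removed override gated the shared-experts-fusion optimization on the platform (CUDA with compute capability >= 8.0), the expected architecture, exactly one shared expert, and the absence of expert parallelism; after this commit the GLM-4.5V class no longer carries its own copy of that check. As a rough, standalone restatement of the removed gating logic, here is a sketch with hypothetical argument names standing in for values the real code reads from `self.config` and sglang's global server args:

```python
import torch

def pick_num_fused_shared_experts(
    n_shared_experts: int,
    architecture: str,
    expected_architecture: str = "Glm4MoeForCausalLM",
    moe_ep_size: int = 1,
    fusion_disabled: bool = False,
) -> int:
    """Illustrative restatement of the deleted determine_num_fused_shared_experts."""
    if fusion_disabled:
        return 0
    # Fusion was only enabled on CUDA devices with compute capability >= 8.0.
    if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 0):
        return 0
    # It also required a matching architecture and exactly one shared expert.
    if architecture != expected_architecture or n_shared_experts != 1:
        return 0
    # Fusion was disabled under expert parallelism.
    if moe_ep_size > 1:
        return 0
    return n_shared_experts
```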
@@ -130,117 +94,14 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]

-        if self.num_fused_shared_experts > 0:
-            assert self.num_fused_shared_experts == 1
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is not None:
-                if self.quant_config.get_name() == "w8a8_int8":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif (
-                    self.quant_config.get_name() == "fp8"
-                    or self.quant_config.get_name() == "blockwise_int8"
-                    or self.quant_config.get_name() == "compressed_tensors"
-                ):
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                    ]
-                elif self.quant_config.get_name() == "awq":
-                    suffix_list = [
-                        "down_proj.qweight",
-                        "down_proj.qzeros",
-                        "down_proj.scales",
-                        "gate_proj.qweight",
-                        "gate_proj.qzeros",
-                        "gate_proj.scales",
-                        "up_proj.qweight",
-                        "up_proj.qzeros",
-                        "up_proj.scales",
-                    ]
-                elif self.quant_config.get_name() == "modelopt_fp4":
-                    suffix_list = [
-                        "down_proj.weight",
-                        "down_proj.weight_scale",
-                        "down_proj.weight_scale_2",
-                        "down_proj.input_scale",
-                        "gate_proj.weight",
-                        "gate_proj.weight_scale",
-                        "gate_proj.weight_scale_2",
-                        "gate_proj.input_scale",
-                        "up_proj.weight",
-                        "up_proj.weight_scale",
-                        "up_proj.weight_scale_2",
-                        "up_proj.input_scale",
-                    ]
-                else:
-                    raise ValueError(
-                        f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
-                    )
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "gate_proj.weight",
-                    "up_proj.weight",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in moe_layers:
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    # online fp8 quantization does not load weight_scale
-                    if shared_expert_weight_name not in weights_dict:
-                        continue
-                    weights_list.append(
-                        (
-                            f"model.layers.{moe_layer}."
-                            f"mlp.experts."
-                            f"{self.config.n_routed_experts + 0}"
-                            f".{suffix}",
-                            weights_dict[shared_expert_weight_name],
-                        )
-                    )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
-
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
+            num_experts=self.config.n_routed_experts,
         )

-        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
-        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
-            self.config.q_lora_rank is not None
-        )
-        cached_a_proj = {} if fuse_qkv_a_proj else None
-
         if is_nextn:
             nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
             nextn_spec_weight_names = [
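The block removed above implemented shared-expert fusion at weight-load time: for each MoE layer it rewrote the checkpoint's `mlp.shared_experts.*` tensors to an extra routed-expert slot (`mlp.experts.<n_routed_experts>.*`) so the shared expert could ride through the fused MoE kernels. A minimal, self-contained sketch of that renaming idea follows; the names and shapes are illustrative, and the real code additionally handled quantization-specific suffixes and the NextN layer.

```python
# Sketch of the removed shared-expert fusion remapping, assuming a plain
# (name -> tensor) weight dict and one shared expert per MoE layer.
from typing import Dict
import torch

def fuse_shared_expert_weights(
    weights: Dict[str, torch.Tensor],
    moe_layers: range,
    n_routed_experts: int,
) -> Dict[str, torch.Tensor]:
    fused = dict(weights)
    for layer in moe_layers:
        for suffix in ("gate_proj.weight", "up_proj.weight", "down_proj.weight"):
            src = f"model.layers.{layer}.mlp.shared_experts.{suffix}"
            if src not in fused:
                continue  # e.g. online fp8 quantization ships no weight_scale entry
            # Re-register the shared expert as routed expert index n_routed_experts.
            dst = f"model.layers.{layer}.mlp.experts.{n_routed_experts}.{suffix}"
            fused[dst] = fused.pop(src)
    return fused

# Example: two MoE layers, 4 routed experts -> the shared expert becomes expert 4.
dummy = {
    "model.layers.1.mlp.shared_experts.gate_proj.weight": torch.zeros(8, 4),
    "model.layers.2.mlp.shared_experts.up_proj.weight": torch.zeros(8, 4),
}
print(sorted(fuse_shared_expert_weights(dummy, range(1, 3), 4)))
```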
@@ -300,23 +161,36 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                     # name will be updated to mlp.experts[0].gate_up_proj, which
                     # will then be updated below in expert_params_mapping
                     # for mlp.experts[0].gate_gate_up_proj, which breaks load.
-                    if ("mlp.experts." in name) and name not in params_dict:
+                    if "mlp.experts" in name:
                         continue
                     name = name.replace(weight_name, param_name)
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    param = params_dict[name]
+                    if name not in params_dict:
+                        continue
+
+                    param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
                     break
                 else:
+                    # Track if this is an expert weight to enable early skipping
+                    is_expert_weight = False
+
                     for mapping in expert_params_mapping:
                         param_name, weight_name, expert_id, shard_id = mapping
                         if weight_name not in name:
                             continue
+
+                        # Mark as expert weight regardless of whether we can process it
+                        is_expert_weight = True
+
                         name = name.replace(weight_name, param_name)
+                        if name not in params_dict:
+                            # Expert weight not on this rank, will be skipped below
+                            continue
+
                         param = params_dict[name]
                         weight_loader = param.weight_loader
                         weight_loader(
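The new `is_expert_weight` flag leans on Python's `for ... else` construct: the `else` suite after a loop runs only when the loop finished without `break`, so marking a weight as expert-owned before the per-rank check lets the fallback branch below skip it cleanly. A small self-contained illustration of that control flow (the variable names mirror the diff, but the data is made up):

```python
# Minimal illustration of the for/else early-skip pattern used above.
expert_params_on_this_rank = {"experts.0.gate_proj", "experts.1.gate_proj"}
checkpoint_names = ["experts.7.gate_proj", "embed_tokens.weight"]

for name in checkpoint_names:
    is_expert_weight = False
    for prefix in ("experts.",):
        if prefix not in name:
            continue
        is_expert_weight = True          # it is an expert weight...
        if name not in expert_params_on_this_rank:
            continue                     # ...but it lives on another rank
        print(f"load {name}")
        break
    else:
        # Reached only when no `break` happened above.
        if is_expert_weight:
            print(f"skip {name}: expert weight not mapped to this rank")
            continue
        print(f"load non-expert weight {name}")
```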
@@ -328,64 +202,21 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                         )
                         break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
+
                     if "visual" in name:
-                        # adapt to VisionAttention
+                        # adapt to VisionAttention for GLM-V
                         name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")

                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
-                    if fuse_qkv_a_proj and (
-                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
-                    ):
-                        cached_a_proj[name] = loaded_weight
-                        q_a_proj_name = (
-                            name
-                            if "q_a_proj" in name
-                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
-                        )
-                        kv_a_proj_name = (
-                            name
-                            if "kv_a_proj_with_mqa" in name
-                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
-                        )
-
-                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
-                        if (
-                            q_a_proj_name in cached_a_proj
-                            and kv_a_proj_name in cached_a_proj
-                        ):
-                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
-                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
-                            fused_weight = torch.cat(
-                                [q_a_proj_weight, kv_a_proj_weight], dim=0
-                            )
-                            param_name = (
-                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
-                                if "q_a_proj" in name
-                                else name.replace(
-                                    "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
-                                )
-                            )
-                            param = params_dict[param_name]
+                    if name not in params_dict:
+                        continue

-                            weight_loader = getattr(
-                                param, "weight_loader", default_weight_loader
-                            )
-                            weight_loader(param, fused_weight)
-                            cached_a_proj.pop(q_a_proj_name)
-                            cached_a_proj.pop(kv_a_proj_name)
-                    else:
-                        if (
-                            "k_scale" in name or "v_scale" in name
-                        ) and name not in params_dict:
-                            # modelopt attn kv scale is named differently
-                            if any(scale in name for scale in ["k_scale", "v_scale"]):
-                                name = name.replace("_proj", "attn_mqa")
-                            else:
-                                logger.warning(
-                                    f"Unknown scale found in checkpoint: {name}"
-                                )
-                        if name in params_dict.keys():
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
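The deleted branch fused the low-rank attention projections at load time: once both the `q_a_proj` and `kv_a_proj_with_mqa` tensors of a layer had been seen, it concatenated them along the output dimension and loaded the result into a single `fused_qkv_a_proj_with_mqa` parameter. This is a DeepSeek-style MLA path gated on `q_lora_rank`, so it is presumably dead code for GLM-4.5V and could be dropped. A toy sketch of the caching-and-concatenation idea, with made-up shapes and names:

```python
# Toy sketch of the removed q_a_proj / kv_a_proj_with_mqa fusion. Shapes and
# names are illustrative; the real code keys the cache by full checkpoint
# names and loads the fused tensor into an sglang parameter.
import torch

cached_a_proj = {}
incoming = [
    ("layers.0.q_a_proj.weight", torch.randn(1536, 4096)),
    ("layers.0.kv_a_proj_with_mqa.weight", torch.randn(576, 4096)),
]

for name, weight in incoming:
    cached_a_proj[name] = weight
    q_name = name if "q_a_proj" in name else name.replace("kv_a_proj_with_mqa", "q_a_proj")
    kv_name = name if "kv_a_proj_with_mqa" in name else name.replace("q_a_proj", "kv_a_proj_with_mqa")
    if q_name in cached_a_proj and kv_name in cached_a_proj:
        # Concatenate along the output dimension -> one fused projection weight.
        fused = torch.cat([cached_a_proj.pop(q_name), cached_a_proj.pop(kv_name)], dim=0)
        print("fused_qkv_a_proj_with_mqa:", tuple(fused.shape))  # (2112, 4096)
```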
@@ -395,6 +226,8 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
                             self.config, name, loaded_weight
                         )
                     weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")


 EntryClass = [Glm4vMoeForConditionalGeneration]
python/sglang/srt/multimodal/processors/glm4v.py
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)

-        # GLM-4.1V and GLM-4.5V specific tokens
+        # GLM-V specific tokens
         self.IMAGE_TOKEN = "<|image|>"
         self.VIDEO_TOKEN = "<|video|>"
         self.IMAGE_START_TOKEN = "<|begin_of_image|>"
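The processor keeps the GLM-V multimodal placeholder tokens as plain strings; this hunk only generalizes the comment above them. As a rough illustration of how such placeholders typically appear in a prompt before the processor expands them into image features, here is a hypothetical snippet; the end-of-image token and the prompt layout are assumptions for illustration, not taken from this diff:

```python
# Hypothetical example of composing a GLM-V style prompt with the placeholder
# tokens defined above; the surrounding chat template is an assumption.
IMAGE_TOKEN = "<|image|>"
IMAGE_START_TOKEN = "<|begin_of_image|>"
IMAGE_END_TOKEN = "<|end_of_image|>"  # assumed counterpart, not shown in this hunk

prompt = (
    f"{IMAGE_START_TOKEN}{IMAGE_TOKEN}{IMAGE_END_TOKEN}"
    "Describe this image in one sentence."
)
print(prompt)
```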