GitLab · change / sglang · Commits · a889c854
Commit a889c854 (unverified), authored Nov 06, 2025 by Yuxuan Zhang, committed via GitHub on Nov 06, 2025

[Grammar Fix] GLM-4-MOE self.first_k_dense_replace is undefined. (#12455)
Parent: 4d84f886

Showing 1 changed file with 214 additions and 65 deletions (whitespace changes hidden):

python/sglang/srt/models/glm4_moe.py (+214, -65)
@@ -15,7 +15,7 @@
 """Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F

@@ -84,6 +84,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     make_layers,
 )

@@ -142,14 +143,17 @@ class Glm4MoeMLP(nn.Module):
         self,
         x,
         forward_batch=None,
-        should_allreduce_fusion=False,
+        should_allreduce_fusion: bool = False,
+        use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
+        x, _ = self.down_proj(
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
+        )
         return x

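Note (not part of the diff): the new use_reduce_scatter flag lets the MLP skip its own tensor-parallel all-reduce when the layer communicator will instead perform a reduce-scatter over data-parallel padding, or fuse the reduction into the next layer. The standalone CPU sketch below illustrates why deferring is safe under that assumption: a reduce-scatter followed by a later all-gather reproduces the all-reduce result. Names and shapes here are illustrative only; "ranks" are just list entries.

import torch

tp_size = 4
tokens, hidden = 8, 16
# Each rank holds a partial sum of the down-projection output.
partials = [torch.randn(tokens, hidden) for _ in range(tp_size)]

# Option 1: all-reduce -- every rank ends up with the full sum.
allreduced = sum(partials)

# Option 2: reduce-scatter -- rank r keeps only its token shard of the sum.
shards = [
    sum(p.chunk(tp_size, dim=0)[r] for p in partials) for r in range(tp_size)
]

# Concatenating the shards (what a later all-gather would do) matches option 1,
# which is why down_proj may pass skip_all_reduce=... or use_reduce_scatter.
assert torch.allclose(torch.cat(shards, dim=0), allreduced, atol=1e-5)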
@@ -442,63 +446,14 @@ class Glm4MoeSparseMoeBlock(nn.Module):
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
-        if not self._enable_a2a_moe:
-            DUAL_STREAM_TOKEN_THRESHOLD = 1024
-            if (
-                self.alt_stream is not None
-                and hidden_states.shape[0] > 0
-                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
-            ):
-                return self.forward_normal_dual_stream(
-                    hidden_states,
-                    should_allreduce_fusion,
-                    use_reduce_scatter,
-                )
-            else:
-                return self.forward_normal(
-                    hidden_states,
-                    should_allreduce_fusion,
-                    use_reduce_scatter,
-                )
+        if not get_moe_a2a_backend().is_deepep():
+            return self.forward_normal(
+                hidden_states, should_allreduce_fusion, use_reduce_scatter
+            )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
 
-    def forward_normal_dual_stream(
-        self,
-        hidden_states: torch.Tensor,
-        should_allreduce_fusion: bool = False,
-        use_reduce_scatter: bool = False,
-    ) -> torch.Tensor:
-        current_stream = torch.cuda.current_stream()
-        self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(hidden_states)
-
-        with torch.cuda.stream(self.alt_stream):
-            # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
-            topk_output = self.topk(hidden_states, router_logits)
-            final_hidden_states = self.experts(hidden_states, topk_output)
-            if not _is_cuda:
-                final_hidden_states *= self.routed_scaling_factor
-        current_stream.wait_stream(self.alt_stream)
-
-        with use_symmetric_memory(
-            parallel_state.get_tp_group(), disabled=not is_allocation_symmetric()
-        ):
-            final_hidden_states_out = torch.empty_like(final_hidden_states)
-            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
-            final_hidden_states = final_hidden_states_out
-        if (
-            self.tp_size > 1
-            and not should_allreduce_fusion
-            and not use_reduce_scatter
-            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
-        ):
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
-        return final_hidden_states
-
     def forward_normal(
         self,
         hidden_states: torch.Tensor,

@@ -534,11 +489,13 @@ class Glm4MoeSparseMoeBlock(nn.Module):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states
 
-    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
         shared_output = None
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
-            router_logits, _ = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
             topk_output = self.topk(
                 hidden_states,

@@ -556,7 +513,15 @@ class Glm4MoeSparseMoeBlock(nn.Module):
             )
 
         if shared_output is not None:
-            final_hidden_states.add_(shared_output)
+            x = shared_output
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states

@@ -566,6 +531,82 @@ class Glm4MoeSparseMoeBlock(nn.Module):
             shared_output = self.shared_experts(hidden_states)
         return shared_output
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_output = self.topk(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+        else:
+            state.topk_output = self.topk.empty_topk_output(hidden_states.device)
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            self.experts.dispatcher.dispatch_a(
+                hidden_states=state.hidden_states_mlp_input,
+                topk_output=state.pop("topk_output"),
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.dispatch_output = self.experts.dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.combine_input = self.experts.run_moe_core(
+            dispatch_output=state.dispatch_output,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.experts.dispatcher.combine_a(
+                combine_input=state.pop("combine_input"),
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+            state.pop("dispatch_output")
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.experts.dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        final_hidden_states = state.pop("hidden_states_after_combine")
+
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
+
+        state.hidden_states_mlp_output = final_hidden_states
+
 
 class Glm4MoeDecoderLayer(nn.Module):
     def __init__(

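Note (not part of the diff): these op_* methods split the sparse-MoE forward into separate stages (gate, expert selection, dispatch, expert computation, combine, output) so a scheduler can interleave the stages of two micro-batches and hide all-to-all communication behind compute. The sketch below is illustrative only, not sglang's scheduler: it shows one way such stages could be chained with a small attribute-style state object. The real driver prepares hidden_states_mlp_input, shared_output, and forward_batch on the state and interleaves two sub-batches; SimpleState and run_stages are hypothetical names.

from types import SimpleNamespace


class SimpleState(SimpleNamespace):
    """Attribute-style state with dict-like pop/get, as the op_* stages expect."""

    def pop(self, key):
        return self.__dict__.pop(key)

    def get(self, key, default=None):
        return self.__dict__.get(key, default)


def run_stages(moe_block, state):
    # Stage order mirrors the op_* methods added in this hunk; a single batch
    # is run straight through instead of overlapping two micro-batches.
    stages = (
        moe_block.op_gate,
        moe_block.op_select_experts,
        moe_block.op_dispatch_a,
        moe_block.op_dispatch_b,
        moe_block.op_experts,
        moe_block.op_combine_a,
        moe_block.op_combine_b,
        moe_block.op_output,
    )
    for stage in stages:
        stage(state)
    return state.hidden_states_mlp_output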
@@ -670,6 +711,7 @@ class Glm4MoeDecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_states, residual = self.layer_communicator.prepare_attn(
             hidden_states, residual, forward_batch
         )

@@ -684,14 +726,96 @@ class Glm4MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )
 
-        hidden_states = self.mlp(hidden_states, forward_batch)
-
-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        should_allreduce_fusion = (
+            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
+                forward_batch
+            )
+        )
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.mlp(
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+        )
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+        else:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )
 
         return hidden_states, residual
 
+    def op_comm_prepare_attn(
+        self,
+        state,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+        tbo_subbatch_index: Optional[int] = None,
+    ):
+        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
+            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
+        )
+        state.update(
+            dict(
+                forward_batch=forward_batch,
+                positions=positions,
+                tbo_subbatch_index=tbo_subbatch_index,
+            )
+        )
+
+    def op_comm_prepare_mlp(self, state):
+        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
+            self.layer_communicator.prepare_mlp(
+                state.pop("hidden_states_after_attn"),
+                state.pop("residual_after_input_ln"),
+                state.forward_batch,
+            )
+        )
+
+    def op_mlp(self, state):
+        hidden_states = state.pop("hidden_states_mlp_input")
+        if not (
+            enable_moe_dense_fully_dp()
+            and (not self.is_layer_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            state.hidden_states_mlp_output = self.mlp(hidden_states, state.forward_batch)
+        else:
+            state.hidden_states_mlp_output = hidden_states
+
+    def op_comm_postprocess_layer(self, state):
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            state.pop("hidden_states_mlp_output"),
+            state.pop("residual_after_comm_pre_mlp"),
+            state.forward_batch,
+        )
+
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
+
 
 class Glm4MoeModel(nn.Module):
     def __init__(

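Note (not part of the diff): Glm4MoeDecoderLayer.forward now asks the layer communicator whether the next layer can fuse the MLP all-reduce; if so, it only tags the tensor (_sglang_needs_allreduce_fusion = True) and skips postprocess_layer, leaving the reduction to the fused kernel downstream. Below is a minimal sketch of that tag-and-defer pattern under stated assumptions: the helper names are hypothetical and a no-op stands in for the real tensor-parallel all-reduce.

import torch


def reduce_now(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for tensor_model_parallel_all_reduce(x); a no-op off-cluster.
    return x


def layer_epilogue(x: torch.Tensor, fuse_with_next_layer: bool) -> torch.Tensor:
    if fuse_with_next_layer:
        # Defer: tag the tensor and let the next layer's fused kernel reduce it.
        x._needs_allreduce_fusion = True  # hypothetical tag, mirrors the diff's flag
        return x
    return reduce_now(x)


def next_layer_prologue(x: torch.Tensor) -> torch.Tensor:
    if getattr(x, "_needs_allreduce_fusion", False):
        x = reduce_now(x)  # in the real path this is fused with the norm kernel
        x._needs_allreduce_fusion = False
    return x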
@@ -704,6 +828,7 @@ class Glm4MoeModel(nn.Module):
         self.pp_group = get_pp_group()
         self.config = config
         self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
         self.embed_dim = config.hidden_size
         if self.pp_group.is_first_rank:
             self.embed_tokens = VocabParallelEmbedding(

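Note (not part of the diff): this one-line assignment is the fix named in the commit title. Glm4MoeModel reads self.first_k_dense_replace (the number of leading layers that keep a dense MLP before MoE layers begin; see the layer-range hunk further down), but __init__ never stored it, so the attribute lookup raised AttributeError. A standalone illustration of the failure mode and the fix, with hypothetical class names:

from types import SimpleNamespace


class BrokenModel:
    def __init__(self, config):
        self.config = config
        # Missing: self.first_k_dense_replace = config.first_k_dense_replace

    def is_layer_sparse(self, layer_id: int) -> bool:
        # Raises AttributeError: no attribute 'first_k_dense_replace'
        return layer_id >= self.first_k_dense_replace


class FixedModel(BrokenModel):
    def __init__(self, config):
        super().__init__(config)
        self.first_k_dense_replace = config.first_k_dense_replace  # the added line


config = SimpleNamespace(first_k_dense_replace=1)
print(FixedModel(config).is_layer_sparse(0))  # False: layer 0 keeps a dense MLP
print(FixedModel(config).is_layer_sparse(3))  # True: layer 3 uses the MoE block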
@@ -733,6 +858,8 @@ class Glm4MoeModel(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.layers_to_capture = []
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -766,8 +893,11 @@ class Glm4MoeModel(nn.Module):
         elif self.first_k_dense_replace < normal_start_layer:
             normal_end_layer = normal_start_layer = 0
 
+        aux_hidden_states = []
         for i in range(normal_start_layer, normal_end_layer):
             with get_global_expert_distribution_recorder().with_current_layer(i):
+                if i in self.layers_to_capture:
+                    aux_hidden_states.append(hidden_states + residual)
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions,

@@ -802,7 +932,9 @@ class Glm4MoeModel(nn.Module):
                 hidden_states = self.norm(hidden_states)
             else:
                 hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) == 0:
             return hidden_states
+        return hidden_states, aux_hidden_states
 
 
 class Glm4MoeForCausalLM(nn.Module):

@@ -813,10 +945,10 @@ class Glm4MoeForCausalLM(nn.Module):
         prefix: str = "",
     ) -> None:
         nn.Module.__init__(self)
-        self.pp_group = get_pp_group()
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
+        self.pp_group = get_pp_group()
         self.model = Glm4MoeModel(
             config, quant_config, prefix=add_prefix("model", prefix)
         )

@@ -847,10 +979,13 @@ class Glm4MoeForCausalLM(nn.Module):
         hidden_states = self.model(
             input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
         )
 
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states

@@ -1027,5 +1162,19 @@ class Glm4MoeForCausalLM(nn.Module):
             num_groups=config.n_group,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = [Glm4MoeForCausalLM]

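Note (not part of the diff): set_eagle3_layers_to_capture selects which decoder layers feed auxiliary hidden states to an EAGLE3 speculative-decoding head; the model loop (see the aux_hidden_states.append(hidden_states + residual) hunk above) snapshots the layer input whenever the loop index is in layers_to_capture. A toy sketch of that capture pattern with stand-in layers, not sglang code:

import torch

num_layers, hidden = 12, 8
layers_to_capture = [2, num_layers // 2, num_layers - 3]  # -> [2, 6, 9]

hidden_states = torch.zeros(1, hidden)
residual = torch.zeros(1, hidden)
aux_hidden_states = []

for i in range(num_layers):
    if i in layers_to_capture:
        # Snapshot the layer input (previous layer's output), as the model loop does.
        aux_hidden_states.append(hidden_states + residual)
    # Stand-in for a decoder layer.
    hidden_states, residual = hidden_states + 1.0, residual + 0.5

print(len(aux_hidden_states))  # 3 tensors handed to the speculative-decoding head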