Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5b6acc14
Unverified
Commit
5b6acc14
authored
Aug 06, 2025
by
Cheng Wan
Committed by
GitHub
Aug 06, 2025
Browse files
fix glm4 moe (#8883)
parent
4373df55
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
4 deletions
+19
-4
python/sglang/srt/models/glm4_moe.py
python/sglang/srt/models/glm4_moe.py
+19
-4
No files found.
python/sglang/srt/models/glm4_moe.py
View file @
5b6acc14
...
@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
...
@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self
.
_enable_deepep_moe
=
global_server_args_dict
[
"moe_a2a_backend"
].
is_deepep
()
self
.
_enable_deepep_moe
=
global_server_args_dict
[
"moe_a2a_backend"
].
is_deepep
()
def
forward_normal_dual_stream
(
def
forward_normal_dual_stream
(
self
,
hidden_states
:
torch
.
Tensor
,
can_fuse_mlp_allreduce
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
can_fuse_mlp_allreduce
:
bool
=
False
,
use_reduce_scatter
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
current_stream
=
torch
.
cuda
.
current_stream
()
current_stream
=
torch
.
cuda
.
current_stream
()
...
@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
...
@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
current_stream
.
wait_stream
(
self
.
alt_stream
)
current_stream
.
wait_stream
(
self
.
alt_stream
)
if
self
.
ep_size
>
1
:
if
self
.
ep_size
>
1
:
if
self
.
tp_size
>
1
and
not
can_fuse_mlp_allreduce
:
if
(
self
.
tp_size
>
1
and
not
can_fuse_mlp_allreduce
and
not
use_reduce_scatter
):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
final_hidden_states
)
)
final_hidden_states
+=
shared_output
final_hidden_states
+=
shared_output
else
:
else
:
final_hidden_states
+=
shared_output
final_hidden_states
+=
shared_output
if
self
.
tp_size
>
1
and
not
can_fuse_mlp_allreduce
:
if
(
self
.
tp_size
>
1
and
not
can_fuse_mlp_allreduce
and
not
use_reduce_scatter
):
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
=
tensor_model_parallel_all_reduce
(
final_hidden_states
final_hidden_states
)
)
return
final_hidden_states
return
final_hidden_states
def
forward_normal
(
def
forward_normal
(
self
,
hidden_states
:
torch
.
Tensor
,
can_fuse_mlp_allreduce
:
bool
=
False
self
,
hidden_states
:
torch
.
Tensor
,
can_fuse_mlp_allreduce
:
bool
=
False
,
use_reduce_scatter
:
bool
=
False
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
hasattr
(
self
,
"shared_experts"
)
and
use_intel_amx_backend
(
if
hasattr
(
self
,
"shared_experts"
)
and
use_intel_amx_backend
(
self
.
shared_experts
.
gate_up_proj
self
.
shared_experts
.
gate_up_proj
...
@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
...
@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
layer_scatter_modes
=
self
.
layer_scatter_modes
,
layer_scatter_modes
=
self
.
layer_scatter_modes
,
input_layernorm
=
self
.
input_layernorm
,
input_layernorm
=
self
.
input_layernorm
,
post_attention_layernorm
=
self
.
post_attention_layernorm
,
post_attention_layernorm
=
self
.
post_attention_layernorm
,
allow_reduce_scatter
=
True
,
)
)
def
forward
(
def
forward
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment