Unverified commit 5b6acc14 authored by Cheng Wan, committed by GitHub

fix glm4 moe (#8883)

parent 4373df55
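
This fix threads a use_reduce_scatter flag through the GLM4 MoE forward paths and enables allow_reduce_scatter on the decoder layer, so the block-level tensor-parallel all-reduce is skipped whenever the layer will reduce-scatter the output instead. Below is a minimal, self-contained sketch of the pattern being fixed; the simulated collectives, function names, and toy values are illustrative assumptions, not sglang's actual API.

    # Toy demonstration of why the extra flag matters. Two "ranks" each hold
    # a partial MoE output; reducing them once is correct, twice is not.

    def all_reduce_sum(partials):
        # Simulated tensor-parallel all-reduce: every rank gets the full sum.
        total = sum(partials)
        return [total for _ in partials]

    def moe_forward(partials, tp_size, can_fuse_mlp_allreduce, use_reduce_scatter):
        # The fix: also skip the block-level all-reduce when the decoder
        # layer will perform the reduction itself (use_reduce_scatter=True).
        if tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
            partials = all_reduce_sum(partials)
        return partials

    def layer_epilogue(partials, use_reduce_scatter):
        # Simulated layer communicator: with reduce-scatter enabled it
        # performs the reduction; otherwise the MoE block already did.
        return all_reduce_sum(partials) if use_reduce_scatter else partials

    partials = [1.0, 2.0]  # rank-local partial sums; correct result is 3.0
    for flag in (False, True):
        out = layer_epilogue(moe_forward(list(partials), 2, False, flag), flag)
        assert out == [3.0, 3.0]  # without the flag, the old code gave 6.0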
@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
 
     def forward_normal_dual_stream(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         current_stream.wait_stream(self.alt_stream)
 
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
             final_hidden_states += shared_output
         else:
             final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if (
+                self.tp_size > 1
+                and not can_fuse_mlp_allreduce
+                and not use_reduce_scatter
+            ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
         return final_hidden_states
 
     def forward_normal(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        can_fuse_mlp_allreduce: bool = False,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )
 
     def forward(
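
Note: the use_reduce_scatter argument is presumably computed per batch by the shared DeepseekV2 plumbing and passed into the two methods above; the allow_reduce_scatter=True flag added in Glm4MoeDecoderLayer is what opts GLM4 MoE layers into that reduce-scatter path in the first place.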