Unverified Commit 5b6acc14 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

fix glm4 moe (#8883)

parent 4373df55
...@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -527,7 +527,10 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep() self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
def forward_normal_dual_stream( def forward_normal_dual_stream(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
current_stream = torch.cuda.current_stream() current_stream = torch.cuda.current_stream()
...@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE): ...@@ -548,21 +551,32 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
current_stream.wait_stream(self.alt_stream) current_stream.wait_stream(self.alt_stream)
if self.ep_size > 1: if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce: if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states final_hidden_states
) )
final_hidden_states += shared_output final_hidden_states += shared_output
else: else:
final_hidden_states += shared_output final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce: if (
self.tp_size > 1
and not can_fuse_mlp_allreduce
and not use_reduce_scatter
):
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states final_hidden_states
) )
return final_hidden_states return final_hidden_states
def forward_normal( def forward_normal(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False self,
hidden_states: torch.Tensor,
can_fuse_mlp_allreduce: bool = False,
use_reduce_scatter: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend( if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj self.shared_experts.gate_up_proj
...@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer): ...@@ -681,6 +695,7 @@ class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
layer_scatter_modes=self.layer_scatter_modes, layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm, input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm, post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
) )
def forward( def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment