Unverified Commit 2ae95d17 authored by Minglei Zhu's avatar Minglei Zhu Committed by GitHub
Browse files

Disable tp for shared experts under expert parallelism for GLM4.5 model (#8647) (#8647)


Co-authored-by: default avatarStefan He <hebiaobuaa@gmail.com>
Co-authored-by: default avatarCheng Wan <54331508+ch-wan@users.noreply.github.com>
parent 2d401bd9
......@@ -387,6 +387,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
self.ep_size = get_moe_expert_parallel_world_size()
self.routed_scaling_factor = config.routed_scaling_factor
self.n_shared_experts = config.n_shared_experts
self.num_fused_shared_experts = (
......@@ -480,11 +481,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**(dict(tp_rank=0, tp_size=1) if self.ep_size > 1 else {}),
)
is_packed_weight = hasattr(
self.shared_experts.gate_up_proj.quant_method, "quant_config"
......@@ -531,6 +528,77 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
self._enable_deepep_moe = global_server_args_dict["moe_a2a_backend"].is_deepep()
def forward_normal_dual_stream(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
) -> torch.Tensor:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states)
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
current_stream.wait_stream(self.alt_stream)
if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
final_hidden_states += shared_output
else:
final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
return final_hidden_states
def forward_normal(
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
) -> torch.Tensor:
if hasattr(self, "shared_experts") and use_intel_amx_backend(
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if self.ep_size > 1:
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
if shared_output is not None:
final_hidden_states += shared_output
else:
if shared_output is not None:
final_hidden_states += shared_output
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states
)
return final_hidden_states
class Glm4MoeDecoderLayer(DeepseekV2DecoderLayer):
def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment