Unverified Commit 4e67a8f6 authored by Yakine Tahtah's avatar Yakine Tahtah Committed by GitHub
Browse files

[Bugfix] Fix GLM-4 MoE router logits dtype for data parallel chunking (#31055)


Signed-off-by: default avatarReinforcedKnowledge <reinforced.knowledge@gmail.com>
parent 142c4d17
......@@ -1006,6 +1006,9 @@ class FusedMoEConfig:
# The activation type.
in_dtype: torch.dtype
# Defaults to in_dtype if not specified.
router_logits_dtype: torch.dtype | None = None
max_num_tokens: int = envs.VLLM_MOE_DP_CHUNK_SIZE
has_bias: bool = False
......@@ -1022,6 +1025,9 @@ class FusedMoEConfig:
assert self.max_num_tokens > 0
if self.router_logits_dtype is None:
self.router_logits_dtype = self.in_dtype
@property
def tp_size(self):
return self.moe_parallel_config.tp_size
......
......@@ -314,6 +314,7 @@ class FusedMoE(CustomOp):
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
"""
def __init__(
......@@ -348,6 +349,7 @@ class FusedMoE(CustomOp):
expert_mapping: list[tuple[str, str, int, str]] | None = None,
n_shared_experts: int | None = None,
routing_method_type: int | None = None,
router_logits_dtype: torch.dtype | None = None,
):
super().__init__()
......@@ -559,6 +561,7 @@ class FusedMoE(CustomOp):
num_local_experts=self.local_num_experts,
moe_parallel_config=self.moe_parallel_config,
in_dtype=moe_in_dtype,
router_logits_dtype=router_logits_dtype,
max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
has_bias=has_bias,
is_act_and_mul=is_act_and_mul,
......@@ -1509,7 +1512,9 @@ class FusedMoE(CustomOp):
)
self.batched_router_logits = torch.zeros(
logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device()
logits_shape,
dtype=moe.router_logits_dtype,
device=torch.cuda.current_device(),
)
def select_experts(
......
......@@ -197,6 +197,7 @@ class Glm4MoE(nn.Module):
e_score_correction_bias=self.gate.e_score_correction_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment