Unverified Commit eb5ed207 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Define router_logits_dtype for remaining MoE models (#33737)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 26471636
...@@ -142,6 +142,7 @@ class AfmoeMoE(nn.Module): ...@@ -142,6 +142,7 @@ class AfmoeMoE(nn.Module):
e_score_correction_bias=self.expert_bias, e_score_correction_bias=self.expert_bias,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -300,6 +300,7 @@ class BailingMoE(nn.Module): ...@@ -300,6 +300,7 @@ class BailingMoE(nn.Module):
num_expert_group=self.n_group, num_expert_group=self.n_group,
topk_group=self.topk_group, topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk, use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -71,7 +71,6 @@ class FlexOlmoMoE(nn.Module): ...@@ -71,7 +71,6 @@ class FlexOlmoMoE(nn.Module):
prefix=f"{prefix}.gate", prefix=f"{prefix}.gate",
) )
# Gate always runs at half / full precision for now.
self.experts = FusedMoE( self.experts = FusedMoE(
num_experts=hf_config.num_experts, num_experts=hf_config.num_experts,
top_k=hf_config.num_experts_per_tok, top_k=hf_config.num_experts_per_tok,
...@@ -82,6 +81,7 @@ class FlexOlmoMoE(nn.Module): ...@@ -82,6 +81,7 @@ class FlexOlmoMoE(nn.Module):
quant_config=None, quant_config=None,
tp_size=tp_size, tp_size=tp_size,
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
router_logits_dtype=torch.float32,
) )
self.top_k = hf_config.num_experts_per_tok self.top_k = hf_config.num_experts_per_tok
......
...@@ -236,9 +236,9 @@ class FlashMLP(nn.Module): ...@@ -236,9 +236,9 @@ class FlashMLP(nn.Module):
class LongcatRouter(nn.Module): class LongcatRouter(nn.Module):
def __init__( def __init__(
self, self,
config, config: FlashConfig,
zero_expert_num=0, zero_expert_num: int,
rounter_params_dtype=torch.bfloat16, rounter_params_dtype: torch.dtype,
prefix: str = "", prefix: str = "",
): ):
super().__init__() super().__init__()
...@@ -309,6 +309,7 @@ class LongcatMoe(nn.Module): ...@@ -309,6 +309,7 @@ class LongcatMoe(nn.Module):
prefix=f"{prefix}.experts", prefix=f"{prefix}.experts",
enable_eplb=enable_eplb, enable_eplb=enable_eplb,
routed_scaling_factor=config.routed_scaling_factor, routed_scaling_factor=config.routed_scaling_factor,
router_logits_dtype=self.rounter_params_dtype,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module): ...@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module):
num_expert_group=config.n_group, num_expert_group=config.n_group,
topk_group=config.topk_group, topk_group=config.topk_group,
scoring_func="sigmoid", scoring_func="sigmoid",
router_logits_dtype=self.gate_dtype,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
...@@ -388,6 +388,7 @@ class FusedMoEBlock(nn.Module): ...@@ -388,6 +388,7 @@ class FusedMoEBlock(nn.Module):
routed_scaling_factor=config.moe_router_scaling_factor, routed_scaling_factor=config.moe_router_scaling_factor,
enable_eplb=self.enable_eplb, enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts, num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
) )
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment