Unverified Commit eb5ed207 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Define router_logits_dtype for remaining MoE models (#33737)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 26471636
......@@ -142,6 +142,7 @@ class AfmoeMoE(nn.Module):
e_score_correction_bias=self.expert_bias,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -300,6 +300,7 @@ class BailingMoE(nn.Module):
num_expert_group=self.n_group,
topk_group=self.topk_group,
use_grouped_topk=self.use_grouped_topk,
router_logits_dtype=self.router_dtype,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -71,7 +71,6 @@ class FlexOlmoMoE(nn.Module):
prefix=f"{prefix}.gate",
)
# Gate always runs at half / full precision for now.
self.experts = FusedMoE(
num_experts=hf_config.num_experts,
top_k=hf_config.num_experts_per_tok,
......@@ -82,6 +81,7 @@ class FlexOlmoMoE(nn.Module):
quant_config=None,
tp_size=tp_size,
prefix=f"{prefix}.experts",
router_logits_dtype=torch.float32,
)
self.top_k = hf_config.num_experts_per_tok
......
......@@ -236,9 +236,9 @@ class FlashMLP(nn.Module):
class LongcatRouter(nn.Module):
def __init__(
self,
config,
zero_expert_num=0,
rounter_params_dtype=torch.bfloat16,
config: FlashConfig,
zero_expert_num: int,
rounter_params_dtype: torch.dtype,
prefix: str = "",
):
super().__init__()
......@@ -309,6 +309,7 @@ class LongcatMoe(nn.Module):
prefix=f"{prefix}.experts",
enable_eplb=enable_eplb,
routed_scaling_factor=config.routed_scaling_factor,
router_logits_dtype=self.rounter_params_dtype,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -174,6 +174,7 @@ class MiMoV2MoE(nn.Module):
num_expert_group=config.n_group,
topk_group=config.topk_group,
scoring_func="sigmoid",
router_logits_dtype=self.gate_dtype,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
......@@ -388,6 +388,7 @@ class FusedMoEBlock(nn.Module):
routed_scaling_factor=config.moe_router_scaling_factor,
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts,
router_logits_dtype=torch.float32,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment