"...csrc/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "b93d5ee2ddcf2a9876b5871cbb958016e263336b"
Unverified Commit a589a071 authored by Atream's avatar Atream Committed by GitHub
Browse files

fix moe gate dtype, fix tbo, fix fake dispatch (#7825)

parent f62d75b6
...@@ -66,7 +66,7 @@ def transform_select_experts_inputs( ...@@ -66,7 +66,7 @@ def transform_select_experts_inputs(
info: Optional[ExpertLocationDispatchInfo], info: Optional[ExpertLocationDispatchInfo],
): ):
if (info is not None) and (info.ep_dispatch_algorithm == "fake"): if (info is not None) and (info.ep_dispatch_algorithm == "fake"):
router_logits = torch.randn_like(router_logits) router_logits.uniform_(5, 10)
if correction_bias is not None: if correction_bias is not None:
correction_bias = torch.zeros_like(correction_bias) correction_bias = torch.zeros_like(correction_bias)
return router_logits, correction_bias return router_logits, correction_bias
......
...@@ -499,7 +499,7 @@ def biased_grouped_topk_gpu( ...@@ -499,7 +499,7 @@ def biased_grouped_topk_gpu(
and is_power_of_two(correction_bias.shape[0]) and is_power_of_two(correction_bias.shape[0])
): ):
topk_weights, topk_ids = moe_fused_gate( topk_weights, topk_ids = moe_fused_gate(
gating_output, gating_output.to(dtype=torch.float32),
correction_bias, correction_bias,
num_expert_group, num_expert_group,
topk_group, topk_group,
......
...@@ -229,7 +229,7 @@ class MoEGate(nn.Module): ...@@ -229,7 +229,7 @@ class MoEGate(nn.Module):
) )
if config.topk_method == "noaux_tc": if config.topk_method == "noaux_tc":
self.e_score_correction_bias = nn.Parameter( self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts)) torch.empty((config.n_routed_experts), dtype=torch.float32)
) )
else: else:
self.e_score_correction_bias = None self.e_score_correction_bias = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment