Unverified Commit b102353f authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[MoE] Enable `renormalize=False` in Triton kernels (#8735)

parent 7a27e798
...@@ -183,15 +183,13 @@ class TopK(CustomOp): ...@@ -183,15 +183,13 @@ class TopK(CustomOp):
*, *,
num_token_non_padded: Optional[torch.Tensor] = None, num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
sm_first: bool = False, # only used for triton kernels topk
) -> TopKOutput: ) -> TopKOutput:
if self.use_triton_kernels: if self.use_triton_kernels:
return triton_kernels_topk( # renormalize=True is equivalent to sm_first=False
router_logits=router_logits, routing_data, gather_idx, scatter_idx = routing(
topk=self.top_k, router_logits, self.top_k, sm_first=not self.renormalize
renormalize=self.renormalize,
sm_first=sm_first,
) )
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
else: else:
torch_native = False torch_native = False
return select_experts( return select_experts(
...@@ -647,22 +645,6 @@ def biased_grouped_topk_cpu( ...@@ -647,22 +645,6 @@ def biased_grouped_topk_cpu(
) )
def triton_kernels_topk(
router_logits: torch.Tensor,
topk: int,
renormalize: bool = False,
sm_first: bool = False,
) -> TritonKernelTopKOutput:
"""Top-K routing for Triton kernels MoE."""
assert not renormalize, "Triton kernels topk doesn't support renormalize"
routing_data, gather_idx, scatter_idx = routing(
logits=router_logits,
n_expts_act=topk,
sm_first=sm_first,
)
return TritonKernelTopKOutput(routing_data, gather_idx, scatter_idx)
if _is_cpu and _is_cpu_amx_available: if _is_cpu and _is_cpu_amx_available:
biased_grouped_topk = biased_grouped_topk_cpu biased_grouped_topk = biased_grouped_topk_cpu
grouped_topk = grouped_topk_cpu grouped_topk = grouped_topk_cpu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment