Unverified Commit 75c7ad99 authored by zhrrr's avatar zhrrr Committed by GitHub
Browse files

[Kernel][Performance] Fuse float cast and renormalize to topk softmax kernel (#26717)


Signed-off-by: default avatarzhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: default avatarizhuhaoran <izhuhaoran@qq.com>
parent 5550ff9c
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices, void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices, torch::Tensor& token_expert_indices,
torch::Tensor& gating_output); torch::Tensor& gating_output, bool renormalize);
void moe_sum(torch::Tensor& input, torch::Tensor& output); void moe_sum(torch::Tensor& input, torch::Tensor& output);
......
This diff is collapsed.
...@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs. // Apply topk softmax to the gating outputs.
m.def( m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()"); "token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax); m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results // Calculate the result of moe by summing up the partial results
......
...@@ -1850,9 +1850,10 @@ def topk_softmax( ...@@ -1850,9 +1850,10 @@ def topk_softmax(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor, token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
renormalize: bool = False,
) -> None: ) -> None:
torch.ops._moe_C.topk_softmax( torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
) )
......
...@@ -1074,9 +1074,8 @@ def vllm_topk_softmax( ...@@ -1074,9 +1074,8 @@ def vllm_topk_softmax(
topk_indices, topk_indices,
token_expert_indices, token_expert_indices,
gating_output, gating_output,
renormalize,
) )
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices return topk_weights, topk_indices
...@@ -1113,11 +1112,9 @@ def fused_topk( ...@@ -1113,11 +1112,9 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device M, topk, dtype=torch.int32, device=hidden_states.device
) )
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_func = dispatch_topk_func() topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func( topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
) )
return topk_weights, topk_ids, token_expert_indices return topk_weights, topk_ids, token_expert_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment