Unverified commit 75c7ad99, authored by zhrrr, committed by GitHub
Browse files

[Kernel][Performance] Fuse float cast and renormalize to topk softmax kernel (#26717)


Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: izhuhaoran <izhuhaoran@qq.com>
parent 5550ff9c
......@@ -4,7 +4,7 @@
void topk_softmax(torch::Tensor& topk_weights, torch::Tensor& topk_indices,
torch::Tensor& token_expert_indices,
torch::Tensor& gating_output);
torch::Tensor& gating_output, bool renormalize);
void moe_sum(torch::Tensor& input, torch::Tensor& output);
......
This diff is collapsed.
......@@ -5,7 +5,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
"token_expert_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
// Calculate the result of moe by summing up the partial results
......
......@@ -1850,9 +1850,10 @@ def topk_softmax(
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor,
renormalize: bool = False,
) -> None:
torch.ops._moe_C.topk_softmax(
topk_weights, topk_ids, token_expert_indices, gating_output
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
......
......@@ -1074,9 +1074,8 @@ def vllm_topk_softmax(
topk_indices,
token_expert_indices,
gating_output,
renormalize,
)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_indices
......@@ -1113,11 +1112,9 @@ def fused_topk(
M, topk, dtype=torch.int32, device=hidden_states.device
)
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output_float, renormalize
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
)
return topk_weights, topk_ids, token_expert_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment