Unverified Commit 534c45b9 authored by ZiTian Zhao's avatar ZiTian Zhao Committed by GitHub
Browse files

Improve fast_topk function with type hints and documentation (#22530)


Signed-off-by: default avatarzitian.zhao <zitian.zhao@tencentmusic.com>
parent 3d7363e6
......@@ -736,7 +736,23 @@ def cast_overflow_tensors(
return tensors
def fast_topk(values, topk, dim):
def fast_topk(values: torch.Tensor, topk: int,
dim: int) -> tuple[torch.Tensor, torch.Tensor]:
"""
Optimized topk implementation that uses torch.max for k=1 case.
This function provides better performance for the common case of k=1
by using torch.max instead of the more general torch.topk.
Args:
values: Input tensor to find top-k values from
topk: Number of top values to return (k). Must be > 0.
dim: Dimension along which to compute topk
Returns:
Tuple of (values, indices) where values are the top-k values
and indices are their corresponding indices in the input tensor
"""
if topk == 1:
# Use max along the specified dimension to get both value and index
return torch.max(values, dim=dim, keepdim=True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment