Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
51b2333b
Unverified
Commit
51b2333b
authored
Mar 17, 2026
by
Michael Goin
Committed by
GitHub
Mar 17, 2026
Browse files
[Perf] Optimize top-k search in apply_top_k_top_p_triton sampler (#37225)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
4ed51308
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
45 deletions
+63
-45
vllm/v1/sample/ops/topk_topp_triton.py
vllm/v1/sample/ops/topk_topp_triton.py
+63
-45
No files found.
vllm/v1/sample/ops/topk_topp_triton.py
View file @
51b2333b
...
...
@@ -67,6 +67,29 @@ _PERCENTILE_TO_STD_TABLE = [
# fmt: on
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
    """Fold one tile into the running (minimum, multiplicity) above a pivot.

    Among the elements of `data` selected by `above_mask`, find the
    smallest value and count how many selected elements lie within 1e-9
    of it, then merge that pair into the running state that the caller
    carries from tile to tile.

    Args:
        data: Tile of values being scanned.
        above_mask: Mask marking the elements of `data` that are above
            the pivot.
        min_larger: Running minimum accumulated over earlier tiles.
        num_min_larger: Running occurrence count for `min_larger`.
        sentinel: Value substituted for masked-out elements before the
            min reduction (callers in this file pass +inf).

    Returns:
        The updated `(min_larger, num_min_larger)` pair.

    Merge rule (checked in this order):
        - tile min < running min: take the tile's min and count
        - tile min within 1e-9 of running min: add the tile count to
          the running count
        - otherwise: running min and count are kept unchanged
    """
    # Masked-out lanes take `sentinel` so they cannot win the reduction.
    candidates = tl.where(above_mask, data, sentinel)
    tile_min = tl.min(candidates)

    # Count the selected lanes that match the tile minimum (1e-9 tolerance).
    near_tile_min = tl.abs(data - tile_min) < 1e-9
    tile_cnt = tl.sum(above_mask & near_tile_min)

    # Both comparisons must use the running minimum *before* it is updated.
    replaces_running = tile_min < min_larger
    ties_running = tl.abs(tile_min - min_larger) < 1e-9

    # On a tie the tile count is accumulated; otherwise the product is zero.
    accumulated_cnt = num_min_larger + tile_cnt * ties_running
    num_min_larger = tl.where(replaces_running, tile_cnt, accumulated_cnt)
    min_larger = tl.minimum(min_larger, tile_min)
    return min_larger, num_min_larger
@
triton
.
jit
def
_topk_topp_kernel
(
LOGITS
,
...
...
@@ -188,7 +211,10 @@ def _topk_topp_kernel(
min_larger_1
=
float
(
"inf"
)
num_min_larger_1
=
tl
.
zeros
((),
dtype
=
tl
.
uint32
)
# First pass: Calculate k_pivots_num and min_larger
# Single fused pass: compute k_pivots_num,
# min_larger, and num_min_larger together to avoid
# a second data scan. See _update_min_larger_stats
# for the tile-level merge logic.
for
i
in
range
(
0
,
search_iters
):
offs_n
=
i
*
BLOCK_SIZE_TRUNC
+
tl
.
arange
(
0
,
BLOCK_SIZE_TRUNC
...
...
@@ -198,27 +224,24 @@ def _topk_topp_kernel(
BUFFER_ROW
+
offs_n
,
mask
=
mask_n_2
,
other
=-
float
(
"inf"
)
)
k_pivots_num_0
+=
tl
.
sum
(
logits_blk2
>
k_pivot_0
)
k_pivots_num_1
+=
tl
.
sum
(
logits_blk2
>
k_pivot_1
)
min_larger_0
=
tl
.
minimum
(
min_larger_0
,
tl
.
min
(
logits_blk2
))
min_larger_1
=
tl
.
minimum
(
min_larger_1
,
tl
.
min
(
logits_blk2
))
above_0
=
logits_blk2
>
k_pivot_0
above_1
=
logits_blk2
>
k_pivot_1
k_pivots_num_0
+=
tl
.
sum
(
above_0
)
k_pivots_num_1
+=
tl
.
sum
(
above_1
)
# Second pass: Calculate num_min_larger
for
i
in
range
(
0
,
search_iters
):
offs_n
=
i
*
BLOCK_SIZE_TRUNC
+
tl
.
arange
(
0
,
BLOCK_SIZE_TRUNC
min_larger_0
,
num_min_larger_0
=
_update_min_larger_stats
(
logits_blk2
,
above_0
,
min_larger_0
,
num_min_larger_0
,
float
(
"inf"
),
)
mask_n_2
=
offs_n
<
search_range
logits_blk2
=
tl
.
load
(
BUFFER_ROW
+
offs_n
,
mask
=
mask_n_2
,
other
=-
float
(
"inf"
)
)
num_min_larger_0
+=
tl
.
sum
(
tl
.
abs
(
logits_blk2
-
min_larger_0
)
<
1e-9
)
num_min_larger_1
+=
tl
.
sum
(
tl
.
abs
(
logits_blk2
-
min_larger_1
)
<
1e-9
min_larger_1
,
num_min_larger_1
=
_update_min_larger_stats
(
logits_blk2
,
above_1
,
min_larger_1
,
num_min_larger_1
,
float
(
"inf"
),
)
# Check if any of the pivots satisfy termination condition
...
...
@@ -272,26 +295,8 @@ def _topk_topp_kernel(
min_larger_1
=
float
(
"inf"
)
num_min_larger_1
=
tl
.
zeros
((),
dtype
=
tl
.
uint32
)
# First pass: Calculate k_pivots_num and min_larger
for
i
in
range
(
0
,
NUM_TILES
):
offs_n
=
i
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask_n
=
offs_n
<
VOCAB_SIZE
logits_blk2
=
tl
.
load
(
LOGITS_ROW
+
offs_n
,
mask
=
mask_n
,
other
=-
float
(
"inf"
)
)
k_pivots_num_0
+=
tl
.
sum
(
logits_blk2
>
k_pivot_0
)
k_pivots_num_1
+=
tl
.
sum
(
logits_blk2
>
k_pivot_1
)
# Exclude -inf from min_larger to avoid
# poisoning the convergence check.
finite_blk2
=
tl
.
where
(
logits_blk2
>
-
float
(
"inf"
),
logits_blk2
,
float
(
"inf"
)
)
min_larger_0
=
tl
.
minimum
(
min_larger_0
,
tl
.
min
(
finite_blk2
))
min_larger_1
=
tl
.
minimum
(
min_larger_1
,
tl
.
min
(
finite_blk2
))
# Second pass: Calculate num_min_larger
# Single fused pass over full vocab (same approach
# as the buffer path above).
for
i
in
range
(
0
,
NUM_TILES
):
offs_n
=
i
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask_n
=
offs_n
<
VOCAB_SIZE
...
...
@@ -299,11 +304,24 @@ def _topk_topp_kernel(
LOGITS_ROW
+
offs_n
,
mask
=
mask_n
,
other
=-
float
(
"inf"
)
)
num_min_larger_0
+=
tl
.
sum
(
tl
.
abs
(
logits_blk2
-
min_larger_0
)
<
1e-9
)
num_min_larger_1
+=
tl
.
sum
(
tl
.
abs
(
logits_blk2
-
min_larger_1
)
<
1e-9
above_0
=
logits_blk2
>
k_pivot_0
above_1
=
logits_blk2
>
k_pivot_1
k_pivots_num_0
+=
tl
.
sum
(
above_0
)
k_pivots_num_1
+=
tl
.
sum
(
above_1
)
min_larger_0
,
num_min_larger_0
=
_update_min_larger_stats
(
logits_blk2
,
above_0
,
min_larger_0
,
num_min_larger_0
,
float
(
"inf"
),
)
min_larger_1
,
num_min_larger_1
=
_update_min_larger_stats
(
logits_blk2
,
above_1
,
min_larger_1
,
num_min_larger_1
,
float
(
"inf"
),
)
# Check if any of the pivots satisfy termination condition
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment