Unverified Commit 51b2333b authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Perf] Optimize top-k search in apply_top_k_top_p_triton sampler (#37225)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 4ed51308
...@@ -67,6 +67,29 @@ _PERCENTILE_TO_STD_TABLE = [ ...@@ -67,6 +67,29 @@ _PERCENTILE_TO_STD_TABLE = [
# fmt: on # fmt: on
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
"""Update running (min, count) of values above a pivot across tiles.
Tracks the smallest value strictly above a pivot and how many times
it occurs. Called once per tile per pivot; the running state is
carried across tiles via `min_larger` / `num_min_larger`.
Merge rule:
- tile min < running min → replace both
- tile min == running min → accumulate count
- tile min > running min → keep running values
"""
tile_min = tl.min(tl.where(above_mask, data, sentinel))
tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9)
tile_cnt = tl.sum(tile_eq)
is_new = tile_min < min_larger
is_same = tl.abs(tile_min - min_larger) < 1e-9
num_min_larger = tl.where(is_new, tile_cnt, num_min_larger + tile_cnt * is_same)
min_larger = tl.minimum(min_larger, tile_min)
return min_larger, num_min_larger
@triton.jit @triton.jit
def _topk_topp_kernel( def _topk_topp_kernel(
LOGITS, LOGITS,
...@@ -188,7 +211,10 @@ def _topk_topp_kernel( ...@@ -188,7 +211,10 @@ def _topk_topp_kernel(
min_larger_1 = float("inf") min_larger_1 = float("inf")
num_min_larger_1 = tl.zeros((), dtype=tl.uint32) num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
# First pass: Calculate k_pivots_num and min_larger # Single fused pass: compute k_pivots_num,
# min_larger, and num_min_larger together to avoid
# a second data scan. See _update_min_larger_stats
# for the tile-level merge logic.
for i in range(0, search_iters): for i in range(0, search_iters):
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(
0, BLOCK_SIZE_TRUNC 0, BLOCK_SIZE_TRUNC
...@@ -198,27 +224,24 @@ def _topk_topp_kernel( ...@@ -198,27 +224,24 @@ def _topk_topp_kernel(
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")
) )
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) above_0 = logits_blk2 > k_pivot_0
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) above_1 = logits_blk2 > k_pivot_1
k_pivots_num_0 += tl.sum(above_0)
min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) k_pivots_num_1 += tl.sum(above_1)
min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2))
# Second pass: Calculate num_min_larger min_larger_0, num_min_larger_0 = _update_min_larger_stats(
for i in range(0, search_iters): logits_blk2,
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange( above_0,
0, BLOCK_SIZE_TRUNC min_larger_0,
num_min_larger_0,
float("inf"),
) )
mask_n_2 = offs_n < search_range min_larger_1, num_min_larger_1 = _update_min_larger_stats(
logits_blk2 = tl.load( logits_blk2,
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf") above_1,
) min_larger_1,
num_min_larger_1,
num_min_larger_0 += tl.sum( float("inf"),
tl.abs(logits_blk2 - min_larger_0) < 1e-9
)
num_min_larger_1 += tl.sum(
tl.abs(logits_blk2 - min_larger_1) < 1e-9
) )
# Check if any of the pivots satisfy termination condition # Check if any of the pivots satisfy termination condition
...@@ -272,26 +295,8 @@ def _topk_topp_kernel( ...@@ -272,26 +295,8 @@ def _topk_topp_kernel(
min_larger_1 = float("inf") min_larger_1 = float("inf")
num_min_larger_1 = tl.zeros((), dtype=tl.uint32) num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
# First pass: Calculate k_pivots_num and min_larger # Single fused pass over full vocab (same approach
for i in range(0, NUM_TILES): # as the buffer path above).
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < VOCAB_SIZE
logits_blk2 = tl.load(
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
)
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
# Exclude -inf from min_larger to avoid
# poisoning the convergence check.
finite_blk2 = tl.where(
logits_blk2 > -float("inf"), logits_blk2, float("inf")
)
min_larger_0 = tl.minimum(min_larger_0, tl.min(finite_blk2))
min_larger_1 = tl.minimum(min_larger_1, tl.min(finite_blk2))
# Second pass: Calculate num_min_larger
for i in range(0, NUM_TILES): for i in range(0, NUM_TILES):
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < VOCAB_SIZE mask_n = offs_n < VOCAB_SIZE
...@@ -299,11 +304,24 @@ def _topk_topp_kernel( ...@@ -299,11 +304,24 @@ def _topk_topp_kernel(
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf") LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
) )
num_min_larger_0 += tl.sum( above_0 = logits_blk2 > k_pivot_0
tl.abs(logits_blk2 - min_larger_0) < 1e-9 above_1 = logits_blk2 > k_pivot_1
) k_pivots_num_0 += tl.sum(above_0)
num_min_larger_1 += tl.sum( k_pivots_num_1 += tl.sum(above_1)
tl.abs(logits_blk2 - min_larger_1) < 1e-9
min_larger_0, num_min_larger_0 = _update_min_larger_stats(
logits_blk2,
above_0,
min_larger_0,
num_min_larger_0,
float("inf"),
)
min_larger_1, num_min_larger_1 = _update_min_larger_stats(
logits_blk2,
above_1,
min_larger_1,
num_min_larger_1,
float("inf"),
) )
# Check if any of the pivots satisfy termination condition # Check if any of the pivots satisfy termination condition
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment