Commit 22dd9c27 authored by jvlunteren, committed by GitHub

[Kernel] Optimize Prefill Attention in Unified Triton Attention Kernel (#20308)


Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
parent a6d795d5
......
@@ -145,7 +145,19 @@ def kernel_unified_attention_2d(
         mask=query_mask_1,
         other=0.0)
-    num_blocks = cdiv_fn(seq_len, BLOCK_SIZE)
+    # compute the length of the longest sequence prefix spanned by any
+    # query token in the current q_block (q_block_local_idx)
+    max_seq_prefix_len = context_len + q_block_local_idx * BLOCK_Q + (
+        BLOCK_M - 1) // num_queries_per_kv + 1
+    # adjust for potential padding in the last q_block by considering the
+    # actual sequence length
+    max_seq_prefix_len = tl.minimum(max_seq_prefix_len, seq_len)
+    # calculate the number of tiles (blocks) that need to be processed to
+    # cover the longest sequence prefix (due to causal masking, blocks beyond
+    # this prefix can be skipped)
+    num_blocks = cdiv_fn(max_seq_prefix_len, BLOCK_SIZE)
     # iterate through tiles
     for j in range(0, num_blocks):
......
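To make the effect of the new tile bound concrete, here is a minimal plain-Python sketch of the arithmetic in the hunk above. It is not the Triton kernel: `cdiv`, `tiles_before`, `tiles_after`, and `longest_prefix_brute_force` are hypothetical helper names, the numeric values are illustrative, and two things are assumed rather than shown in the diff: that `BLOCK_Q == BLOCK_M // num_queries_per_kv`, and that row `m` of a q_block corresponds to query token `q_block_local_idx * BLOCK_Q + m // num_queries_per_kv`.

```python
def cdiv(a: int, b: int) -> int:
    """Ceiling division, standing in for the kernel's cdiv_fn."""
    return (a + b - 1) // b


def tiles_before(seq_len: int, block_size: int) -> int:
    """Tile count before the change: always cover the whole sequence."""
    return cdiv(seq_len, block_size)


def tiles_after(context_len: int, q_block_local_idx: int, block_q: int,
                block_m: int, num_queries_per_kv: int, seq_len: int,
                block_size: int) -> int:
    """Tile count after the change: cover only the longest causal prefix."""
    max_seq_prefix_len = (context_len + q_block_local_idx * block_q
                          + (block_m - 1) // num_queries_per_kv + 1)
    # The last q_block may be padded, so clamp to the real sequence length.
    max_seq_prefix_len = min(max_seq_prefix_len, seq_len)
    return cdiv(max_seq_prefix_len, block_size)


def longest_prefix_brute_force(context_len: int, q_block_local_idx: int,
                               block_q: int, block_m: int,
                               num_queries_per_kv: int) -> int:
    """Unclamped prefix length, computed row by row (assumed row layout)."""
    return max(
        context_len + q_block_local_idx * block_q + m // num_queries_per_kv + 1
        for m in range(block_m))


# Illustrative values only (assumptions, not taken from the commit).
SEQ_LEN, CONTEXT_LEN = 1024, 0          # pure prefill of 1024 tokens
BLOCK_SIZE, BLOCK_M, NUM_QUERIES_PER_KV = 16, 16, 4
BLOCK_Q = BLOCK_M // NUM_QUERIES_PER_KV  # assumed relationship

num_q_blocks = cdiv(SEQ_LEN - CONTEXT_LEN, BLOCK_Q)
total_before = num_q_blocks * tiles_before(SEQ_LEN, BLOCK_SIZE)
total_after = sum(
    tiles_after(CONTEXT_LEN, i, BLOCK_Q, BLOCK_M, NUM_QUERIES_PER_KV,
                SEQ_LEN, BLOCK_SIZE)
    for i in range(num_q_blocks))

print(total_before, total_after)  # 16384 8320
```

Under these assumptions, the first q_block visits 1 tile instead of 64, and the total tile count over the whole prefill drops by roughly half. For decode, where `context_len` is close to `seq_len`, the clamped prefix is nearly the full sequence, so the bound changes little.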