Unverified commit b5545d9d, authored by Thomas Parnell, committed by GitHub

[Bugfix] [Kernel] Triton attention kernels: mask out V blocks that fall outside sliding window (#30887)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent bd2b52fc
@@ -363,6 +363,12 @@ def kernel_unified_attention_2d(
L = L * alpha + l_j
M = m_j
if SLIDING_WINDOW:
    qpos_lo = q_block_local_idx * BLOCK_Q
    V = tl.where(
        (context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
    )
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
@@ -678,6 +684,12 @@ def kernel_unified_attention_3d(
L = L * alpha + l_j
M = m_j
if SLIDING_WINDOW:
    qpos_lo = q_block_local_idx * BLOCK_Q
    V = tl.where(
        (context_len + qpos_lo - seq_offset[:, None]) < SLIDING_WINDOW, V, 0.0
    )
# acc : (BLOCK_M, HEAD_SIZE_PADDED)
acc += tl.dot(P.to(V.dtype), V)
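
Note on the masking in both hunks: the commit does not state why zeroing V is needed when P already carries the attention mask. One plausible reading, which is an assumption and not something the diff confirms, is that V rows outside the sliding window can hold non-finite values, and under IEEE 754 arithmetic `0.0 * nan` and `0.0 * inf` are both `nan`, so a zero weight in P is not enough to keep them out of `acc`. The pure-Python sketch below illustrates that arithmetic only. The function names, the scalar loop, and the sample data are hypothetical and are not the Triton kernel; `lowest_query_pos` stands in for `context_len + qpos_lo` and `key_pos` for `seq_offset`.

```python
import math


def unmasked_weighted_sum(p_row, v_rows):
    """Accumulate sum_k p[k] * v[k] with no masking of V (hypothetical helper)."""
    head_size = len(v_rows[0])
    acc = [0.0] * head_size
    for p, v in zip(p_row, v_rows):
        for d in range(head_size):
            acc[d] += p * v[d]
    return acc


def masked_weighted_sum(p_row, v_rows, key_pos, lowest_query_pos, sliding_window):
    """Same accumulation, but V rows outside the window are replaced by zeros.

    The keep condition mirrors the diff:
    (context_len + qpos_lo - seq_offset) < SLIDING_WINDOW
    """
    head_size = len(v_rows[0])
    acc = [0.0] * head_size
    for p, v, k in zip(p_row, v_rows, key_pos):
        in_window = (lowest_query_pos - k) < sliding_window
        v_used = v if in_window else [0.0] * head_size
        for d in range(head_size):
            acc[d] += p * v_used[d]
    return acc


# Assumed sample data: keys 0 and 1 fall outside a window of size 2 for the
# lowest query position 3, and their V rows contain non-finite values.
key_pos = [0, 1, 2, 3]
p_row = [0.0, 0.0, 0.25, 0.75]  # attention weights already zero outside the window
v_rows = [
    [math.nan, math.nan],
    [math.inf, math.nan],
    [1.0, 2.0],
    [3.0, 4.0],
]

unmasked = unmasked_weighted_sum(p_row, v_rows)
masked = masked_weighted_sum(
    p_row, v_rows, key_pos, lowest_query_pos=3, sliding_window=2
)
print(unmasked)  # every entry is nan, because 0.0 * nan and 0.0 * inf are nan
print(masked)    # [2.5, 3.5]
```

Because the condition uses the lowest query position in the block, it only zeroes V rows that fall outside the window for that lowest position; rows needed by some queries in the block but not others are left to the weights in P.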