OpenDAS / vllm_cscc
Commit b8bf5c45 (Unverified)
Authored Jan 10, 2026 by jvlunteren
Committed by GitHub, Jan 10, 2026
[Kernel] Optimize Sliding Window Attention in 3D Triton Kernel (#31984)
Signed-off-by: Jan van Lunteren <jvl@zurich.ibm.com>
parent e6c6f2c7
Showing 1 changed file with 26 additions and 3 deletions

vllm/v1/attention/ops/triton_unified_attention.py
...
...
@@ -545,10 +545,33 @@ def kernel_unified_attention_3d(
     # this prefix can be skipped)
     num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

-    # iterate through tiles within current segment
+    # ---- Sliding-window tile pruning --------------------
+    # Default: keep previous global behavior
+    tile_start = 0
+    tile_end = num_tiles
+    # TODO(Isotr0py): sliding window pruning with image bidirectional mask
+    if SLIDING_WINDOW > 0 and not USE_MM_PREFIX:
+        # Query rows covered by this Q-block
+        qpos_lo = q_block_local_idx * BLOCK_Q
+        qpos_hi = tl.minimum(
+            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
+            cur_batch_query_len - 1,
+        )
+        # For sliding window, each query position q can only attend to
+        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
+        # where q_abs = context_len + q
+        # The union of allowed key positions for this Q-block is:
+        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
+        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
+        last_allowed_key = context_len + qpos_hi
+        # Convert to tile indices and clamp
+        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
+        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)
+
+    # iterate through tiles (now limited to the sliding window range)
     for j in range(
-        segm_idx * tiles_per_segment,
-        min((segm_idx + 1) * tiles_per_segment, num_tiles),
+        max(segm_idx * tiles_per_segment, tile_start),
+        min((segm_idx + 1) * tiles_per_segment, tile_end),
     ):
         seq_offset = j * TILE_SIZE + offs_t
         tile_mask = seq_offset < max_seq_prefix_len
...
...
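The tile-range arithmetic in this hunk can be checked outside Triton. The following is a plain-Python sketch, not the kernel itself: the function names `pruned_tile_range`, `segment_tiles`, and `covers_all_allowed_keys` are hypothetical, `num_tiles` is passed in directly rather than derived from `max_seq_prefix_len`, and the example values (context length 1000, window 128, tile size 32) are assumptions chosen for illustration. It mirrors the bounds computed in the diff and brute-forces that no key visible to any query row in the Q-block falls outside the pruned range.

```python
def pruned_tile_range(context_len, qpos_lo, qpos_hi, sliding_window,
                      tile_size, num_tiles):
    """Half-open tile range [tile_start, tile_end) that may hold visible keys."""
    if sliding_window <= 0:
        # No sliding window: keep the full range, as in the diff's default.
        return 0, num_tiles
    # Union of the per-query windows over query rows qpos_lo..qpos_hi.
    first_allowed_key = context_len + qpos_lo - sliding_window + 1
    last_allowed_key = context_len + qpos_hi
    # Convert key positions to tile indices and clamp to [0, num_tiles].
    tile_start = max(0, first_allowed_key // tile_size)
    tile_end = min(last_allowed_key // tile_size + 1, num_tiles)
    return tile_start, tile_end


def segment_tiles(segm_idx, tiles_per_segment, tile_start, tile_end):
    """Tiles of one segment that also lie inside the pruned range."""
    lo = max(segm_idx * tiles_per_segment, tile_start)
    hi = min((segm_idx + 1) * tiles_per_segment, tile_end)
    return range(lo, hi)  # empty when the segment lies outside the window


def covers_all_allowed_keys(context_len, qpos_lo, qpos_hi, sliding_window,
                            tile_size, num_tiles):
    """Brute-force check: every key a query row may attend to is in range."""
    tile_start, tile_end = pruned_tile_range(
        context_len, qpos_lo, qpos_hi, sliding_window, tile_size, num_tiles
    )
    for q in range(qpos_lo, qpos_hi + 1):
        q_abs = context_len + q
        for key in range(max(0, q_abs - sliding_window + 1), q_abs + 1):
            if not tile_start <= key // tile_size < tile_end:
                return False
    return True


# Assumed example: 16 query rows after a 1000-token context, window of 128.
start, end = pruned_tile_range(1000, 0, 15, 128, 32, 32)
print(start, end)                            # 27 32
print(list(segment_tiles(3, 8, start, end)))  # [27, 28, 29, 30, 31]
print(len(segment_tiles(0, 8, start, end)))   # 0
```

With these assumed values, 5 of 32 tiles are visited, and segments that lie entirely before the window produce an empty range, so their loop body never runs.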