Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a2a5f79e
Unverified
Commit
a2a5f79e
authored
Sep 19, 2025
by
qizixi
Committed by
GitHub
Sep 19, 2025
Browse files
Optimize triton unified attention performance for sliding window attention (#24390)
Signed-off-by:
zixi-qi
<
qizixi@meta.com
>
parent
c59a0eca
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
3 deletions
+25
-3
tests/kernels/attention/test_triton_unified_attention.py
tests/kernels/attention/test_triton_unified_attention.py
+1
-1
vllm/attention/ops/triton_unified_attention.py
vllm/attention/ops/triton_unified_attention.py
+24
-2
No files found.
tests/kernels/attention/test_triton_unified_attention.py
View file @
a2a5f79e
...
...
@@ -83,7 +83,7 @@ def ref_paged_attn(
@
pytest
.
mark
.
parametrize
(
"num_heads"
,
NUM_HEADS
)
@
pytest
.
mark
.
parametrize
(
"head_size"
,
HEAD_SIZES
)
@
pytest
.
mark
.
parametrize
(
"block_size"
,
BLOCK_SIZES
)
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
256
])
@
pytest
.
mark
.
parametrize
(
"sliding_window"
,
[
None
,
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
DTYPES
)
@
pytest
.
mark
.
parametrize
(
"soft_cap"
,
[
None
,
50.0
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
NUM_BLOCKS
)
...
...
vllm/attention/ops/triton_unified_attention.py
View file @
a2a5f79e
...
...
@@ -184,8 +184,30 @@ def kernel_unified_attention_2d(
# this prefix can be skipped)
num_tiles
=
cdiv_fn
(
max_seq_prefix_len
,
TILE_SIZE
)
# iterate through tiles
for
j
in
range
(
0
,
num_tiles
):
# ---- Sliding-window tile pruning --------------------
# Default: keep previous global behavior
tile_start
=
0
tile_end
=
num_tiles
if
SLIDING_WINDOW
>
0
:
# Query rows covered by this Q-block
qpos_lo
=
q_block_local_idx
*
BLOCK_Q
qpos_hi
=
tl
.
minimum
(
qpos_lo
+
(
BLOCK_M
-
1
)
//
num_queries_per_kv
,
cur_batch_query_len
-
1
,
)
# For sliding window, each query position q can only attend to
# keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
# where q_abs = context_len + q
# The union of allowed key positions for this Q-block is:
# [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
first_allowed_key
=
context_len
+
qpos_lo
-
SLIDING_WINDOW
+
1
last_allowed_key
=
context_len
+
qpos_hi
# Convert to tile indices and clamp
tile_start
=
tl
.
maximum
(
0
,
first_allowed_key
//
TILE_SIZE
)
tile_end
=
tl
.
minimum
((
last_allowed_key
//
TILE_SIZE
)
+
1
,
num_tiles
)
# iterate through tiles (now limited to the sliding window range)
for
j
in
range
(
tile_start
,
tile_end
):
seq_offset
=
j
*
TILE_SIZE
+
offs_t
tile_mask
=
seq_offset
<
max_seq_prefix_len
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment