Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c363ed6
Unverified
Commit
8c363ed6
authored
Nov 30, 2025
by
Pleaplusone
Committed by
GitHub
Nov 30, 2025
Browse files
[ROCm][Attention] Sliding window support for `AiterFlashAttentionBackend` (#29234)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
64bc09ba
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
224 additions
and
49 deletions
+224
-49
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+224
-49
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
8c363ed6
...
...
@@ -13,8 +13,9 @@ from vllm.attention.backends.abstract import (
AttentionType
,
MultipleOf
,
)
from
vllm.attention.layer
import
Attention
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
...
...
@@ -57,58 +58,55 @@ if current_platform.is_rocm():
head_size
,
x
,
max_block_num
,
num_tokens
,
num_programs
,
DEQUANT
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
CACHE_FORMAT
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
b
id
=
tl
.
program_id
(
0
)
token_
id
=
tl
.
program_id
(
0
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
if
DEQUANT
:
k_scale
=
tl
.
load
(
k_scale_ptr
)
v_scale
=
tl
.
load
(
v_scale_ptr
)
for
token_id
in
tl
.
range
(
bid
,
num_tokens
,
num_programs
):
key_ptr_offset
=
key_ptr
+
token_id
*
head_size
*
num_heads
value_ptr_offset
=
value_ptr
+
token_id
*
head_size
*
num_heads
batch_idx
=
tl
.
load
(
token_to_batch_ptr
+
token_id
)
batch_start
=
tl
.
load
(
seq_start_ptr
+
batch_idx
)
token_start
=
tl
.
load
(
cu_seqlens_kv_ptr
+
batch_idx
)
batch_offset
=
token_id
-
token_start
+
batch_start
block_offset
=
batch_offset
//
PAGE_SIZE
block_id
=
tl
.
load
(
block_table_ptr
+
max_block_num
*
batch_idx
+
block_offset
key_ptr_offset
=
key_ptr
+
token_id
*
head_size
*
num_heads
value_ptr_offset
=
value_ptr
+
token_id
*
head_size
*
num_heads
batch_idx
=
tl
.
load
(
token_to_batch_ptr
+
token_id
)
batch_start
=
tl
.
load
(
seq_start_ptr
+
batch_idx
)
token_start
=
tl
.
load
(
cu_seqlens_kv_ptr
+
batch_idx
)
batch_offset
=
token_id
-
token_start
+
batch_start
block_offset
=
batch_offset
//
PAGE_SIZE
block_id
=
tl
.
load
(
block_table_ptr
+
max_block_num
*
batch_idx
+
block_offset
).
to
(
tl
.
int64
)
slot_id
=
batch_offset
%
PAGE_SIZE
if
CACHE_FORMAT
==
"NHD"
:
# for kv cache layout as
# K: [num_blocks, page_size, num_head, head_dim]
# V: [num_blocks, page_size, num_head, head_dim]
key_cache_ptr_offset
=
(
key_cache_ptr
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
slot_id
*
num_heads
*
head_size
)
value_cache_ptr_offset
=
(
value_cache_ptr
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
slot_id
*
num_heads
*
head_size
)
slot_id
=
batch_offset
%
PAGE_SIZE
if
CACHE_FORMAT
==
"NHD"
:
# for kv cache layout as
# K: [num_blocks, page_size, num_head, head_dim]
# V: [num_blocks, page_size, num_head, head_dim]
key_cache_ptr_offset
=
(
key_cache_ptr
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
slot_id
*
num_heads
*
head_size
)
value_cache_ptr_offset
=
(
value_cache_ptr
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
slot_id
*
num_heads
*
head_size
)
for
i
in
tl
.
range
(
0
,
head_size
*
num_heads
,
BLOCK_SIZE
):
mask
=
(
col_offsets
+
i
)
<
head_size
*
num_heads
k_reg
=
tl
.
load
(
key_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
v_reg
=
tl
.
load
(
value_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
if
DEQUANT
:
k_dtype
=
k_reg
.
dtype
v_dtype
=
v_reg
.
dtype
k_reg
=
(
k_reg
.
to
(
tl
.
float32
)
*
k_scale
).
to
(
k_dtype
)
v_reg
=
(
v_reg
.
to
(
tl
.
float32
)
*
v_scale
).
to
(
v_dtype
)
tl
.
store
(
key_ptr_offset
+
col_offsets
+
i
,
k_reg
,
mask
=
mask
)
tl
.
store
(
value_ptr_offset
+
col_offsets
+
i
,
v_reg
,
mask
=
mask
)
for
i
in
tl
.
range
(
0
,
head_size
*
num_heads
,
BLOCK_SIZE
):
mask
=
(
col_offsets
+
i
)
<
head_size
*
num_heads
k_reg
=
tl
.
load
(
key_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
v_reg
=
tl
.
load
(
value_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
if
DEQUANT
:
k_dtype
=
k_reg
.
dtype
v_dtype
=
v_reg
.
dtype
k_reg
=
(
k_reg
.
to
(
tl
.
float32
)
*
k_scale
).
to
(
k_dtype
)
v_reg
=
(
v_reg
.
to
(
tl
.
float32
)
*
v_scale
).
to
(
v_dtype
)
tl
.
store
(
key_ptr_offset
+
col_offsets
+
i
,
k_reg
,
mask
=
mask
)
tl
.
store
(
value_ptr_offset
+
col_offsets
+
i
,
v_reg
,
mask
=
mask
)
def
cp_mha_gather_cache
(
key_cache
:
torch
.
Tensor
,
...
...
@@ -143,9 +141,7 @@ if current_platform.is_rocm():
page_size
=
key_cache
.
shape
[
1
]
num_heads
=
key_cache
.
shape
[
2
]
NUM_PRGMS
=
num_programs
(
total_tokens
)
BLOCK_SIZE
=
block_size
(
key_cache
,
head_dim
)
grid
=
lambda
meta
:
(
NUM_PRGMS
,)
grid
=
lambda
meta
:
(
total_tokens
,)
cp_mha_gather_cache_kernel
[
grid
](
key_cache
,
value_cache
,
...
...
@@ -161,12 +157,10 @@ if current_platform.is_rocm():
head_dim
,
x
,
block_tables
.
size
(
1
),
total_tokens
,
NUM_PRGMS
,
DEQUANT
=
dequant
,
PAGE_SIZE
=
page_size
,
CACHE_FORMAT
=
kv_cache_layout
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
head_dim
,
)
...
...
@@ -189,6 +183,17 @@ class AiterFlashAttentionPrefillMetadata:
query_start_loc
:
torch
.
Tensor
@
dataclass
class
AiterChunkSlidingWindowMetadata
:
swa_seqlens
:
torch
.
Tensor
swa_cu_seqlens
:
torch
.
Tensor
swa_seq_starts
:
torch
.
Tensor
swa_token_to_batch
:
torch
.
Tensor
swa_max_seqlens
:
int
swa_total_tokens
:
int
swa_workspace
:
torch
.
Tensor
@
dataclass
class
AiterChunkContextMetadata
:
workspace
:
torch
.
Tensor
...
...
@@ -200,6 +205,7 @@ class AiterChunkContextMetadata:
seq_lens
:
torch
.
Tensor
num_chunks
:
int
total_token_per_batch
:
list
[
int
]
swa_metadata
:
AiterChunkSlidingWindowMetadata
|
None
@
dataclass
...
...
@@ -278,6 +284,20 @@ class AiterFlashAttentionMetadataBuilder(
self
.
aot_sliding_window
:
tuple
[
int
,
int
]
|
None
=
None
self
.
total_tokens
:
int
=
0
sliding_window_configs
:
set
[
tuple
[
int
,
int
]
|
None
]
=
set
()
layers
=
get_layers_from_vllm_config
(
self
.
vllm_config
,
Attention
)
for
layer
in
layers
.
values
():
assert
isinstance
(
layer
.
impl
,
AiterFlashAttentionImpl
)
sliding_window_configs
.
add
(
layer
.
impl
.
sliding_window
)
while
len
(
sliding_window_configs
)
>
0
:
sliding_window_config
=
sliding_window_configs
.
pop
()
if
sliding_window_config
is
not
None
and
sliding_window_config
[
0
]
!=
-
1
:
assert
self
.
aot_sliding_window
is
None
,
(
"Aiter Flash ATTENTION can only support one valid sliding window!"
)
self
.
aot_sliding_window
=
sliding_window_config
self
.
extend_workspace
=
torch
.
empty
(
[
2
,
_CP_TOKENS_PER_ITER_ROCM
,
self
.
num_heads_kv
,
self
.
headdim
],
dtype
=
self
.
model_config
.
dtype
,
...
...
@@ -349,6 +369,55 @@ class AiterFlashAttentionMetadataBuilder(
query_lens_for_extend
=
query_lens_cpu
[
num_extends_slice
]
seq_lens_for_extend
=
common_attn_metadata
.
seq_lens_cpu
[
num_extends_slice
]
computed_kv_lens
=
seq_lens_for_extend
-
query_lens_for_extend
swa_metadata
=
None
if
self
.
aot_sliding_window
is
not
None
:
swa_seqlen_for_extend
=
torch
.
minimum
(
seq_lens_for_extend
,
query_lens_for_extend
+
self
.
aot_sliding_window
[
0
]
+
1
,
)
cu_seq_lens
=
torch
.
zeros
(
num_extends
+
1
,
dtype
=
torch
.
int32
,
device
=
seq_lens_for_extend
.
device
,
)
torch
.
cumsum
(
swa_seqlen_for_extend
,
dim
=
0
,
dtype
=
cu_seq_lens
.
dtype
,
out
=
cu_seq_lens
[
1
:],
)
token_to_seq
=
torch
.
arange
(
0
,
num_extends
,
dtype
=
torch
.
int32
,
device
=
seq_lens_for_extend
.
device
,
)
token_to_seq
=
torch
.
repeat_interleave
(
token_to_seq
,
swa_seqlen_for_extend
)
fetched_shape
=
cu_seq_lens
[
-
1
].
item
()
# TODO(ganyi): Maybe reuse these 2 buffer from extend_workspace
swa_workspace
=
torch
.
empty
(
(
2
,
fetched_shape
,
self
.
num_heads_kv
,
self
.
headdim
),
dtype
=
self
.
vllm_config
.
model_config
.
dtype
,
device
=
self
.
device
,
)
seq_starts
=
seq_lens_for_extend
-
swa_seqlen_for_extend
max_seqlen_k
=
swa_seqlen_for_extend
.
max
().
item
()
total_tokens
=
cu_seq_lens
[
-
1
].
item
()
swa_metadata
=
AiterChunkSlidingWindowMetadata
(
swa_seqlens
=
swa_seqlen_for_extend
.
to
(
self
.
device
,
non_blocking
=
True
),
swa_cu_seqlens
=
cu_seq_lens
.
to
(
self
.
device
,
non_blocking
=
True
),
swa_seq_starts
=
seq_starts
.
to
(
self
.
device
,
non_blocking
=
True
),
swa_token_to_batch
=
token_to_seq
.
to
(
self
.
device
,
non_blocking
=
True
),
swa_max_seqlens
=
max_seqlen_k
,
swa_total_tokens
=
total_tokens
,
swa_workspace
=
swa_workspace
,
)
# allocate the equal amount of workspace for
# each chunk prefill request
...
...
@@ -392,6 +461,7 @@ class AiterFlashAttentionMetadataBuilder(
token_to_batch
=
token_to_batch_tensor
.
to
(
self
.
device
,
non_blocking
=
True
),
num_chunks
=
num_chunks
,
total_token_per_batch
=
cu_seq_lens_cpu
[:,
-
1
].
tolist
(),
swa_metadata
=
swa_metadata
,
)
query_start_loc_device
=
common_attn_metadata
.
query_start_loc
[
...
...
@@ -504,9 +574,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
alibi_slopes
=
torch
.
tensor
(
alibi_slopes
,
dtype
=
torch
.
float32
)
self
.
alibi_slopes
=
alibi_slopes
if
sliding_window
is
None
:
self
.
sliding_window
=
[
-
1
,
-
1
]
self
.
sliding_window
=
(
-
1
,
-
1
)
else
:
self
.
sliding_window
=
[
sliding_window
-
1
,
0
]
self
.
sliding_window
=
(
sliding_window
-
1
,
0
)
self
.
kv_cache_dtype
=
kv_cache_dtype
if
logits_soft_cap
is
None
:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
...
...
@@ -522,6 +592,67 @@ class AiterFlashAttentionImpl(AttentionImpl):
"Encoder self-attention is not implemented for FlashAttentionImpl"
)
def
extend_for_sliding_window
(
self
,
attn_metadata
:
AiterFlashAttentionMetadata
,
query
:
torch
.
Tensor
,
key_cache
,
value_cache
,
output
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
block_table
:
torch
.
Tensor
,
k_scale
:
float
,
v_scale
:
float
,
):
assert
attn_metadata
.
extend_metadata
is
not
None
assert
attn_metadata
.
extend_metadata
.
chunk_context_metadata
is
not
None
chunked_metadata
=
attn_metadata
.
extend_metadata
.
chunk_context_metadata
swa_metadata
=
chunked_metadata
.
swa_metadata
assert
swa_metadata
is
not
None
swa_cu_seqlens
=
swa_metadata
.
swa_cu_seqlens
swa_seq_starts
=
swa_metadata
.
swa_seq_starts
swa_token_to_batch
=
swa_metadata
.
swa_token_to_batch
swa_max_seqlens
=
swa_metadata
.
swa_max_seqlens
swa_total_tokens
=
swa_metadata
.
swa_total_tokens
key_fetched
,
value_fetched
=
(
swa_metadata
.
swa_workspace
[
0
],
swa_metadata
.
swa_workspace
[
1
],
)
cp_mha_gather_cache
(
key_cache
=
key_cache
,
value_cache
=
value_cache
,
key
=
key_fetched
,
value
=
value_fetched
,
block_tables
=
block_table
,
k_scales
=
k_scale
,
v_scales
=
v_scale
,
cu_seqlens_kv
=
swa_cu_seqlens
,
token_to_batch
=
swa_token_to_batch
,
seq_starts
=
swa_seq_starts
,
dequant
=
False
,
kv_cache_layout
=
"NHD"
,
total_tokens
=
swa_total_tokens
,
)
aiter
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key_fetched
,
v
=
value_fetched
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
swa_cu_seqlens
,
max_seqlen_q
=
max_seqlen_q
,
max_seqlen_k
=
swa_max_seqlens
,
min_seqlen_q
=
1
,
dropout_p
=
0.0
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
alibi_slopes
=
self
.
alibi_slopes
,
return_lse
=
False
,
out
=
output
,
)
def
extend_forward
(
self
,
attn_metadata
:
AiterFlashAttentionMetadata
,
...
...
@@ -540,6 +671,20 @@ class AiterFlashAttentionImpl(AttentionImpl):
k_scale
:
float
,
v_scale
:
float
,
):
if
self
.
sliding_window
[
0
]
!=
-
1
:
self
.
extend_for_sliding_window
(
attn_metadata
,
query
,
key_cache
,
value_cache
,
output
,
cu_seqlens_q
,
max_seqlen_q
,
block_table
,
k_scale
,
v_scale
,
)
return
out
,
lse
=
aiter
.
flash_attn_varlen_func
(
q
=
query
,
k
=
key
,
...
...
@@ -782,6 +927,36 @@ class AiterFlashAttentionImpl(AttentionImpl):
# calculate for decodes
if
num_decodes
>
0
:
assert
attn_metadata
.
decode_metadata
is
not
None
if
self
.
sliding_window
[
0
]
!=
-
1
:
from
aiter.ops.triton.unified_attention
import
(
unified_attention
,
)
descale_shape
=
(
attn_metadata
.
query_start_loc
[:
num_decodes
].
shape
[
0
]
-
1
,
key_cache
.
shape
[
2
],
)
unified_attention
(
q
=
query
[:
num_decode_tokens
],
k
=
key_cache
,
v
=
value_cache
,
out
=
output
[:
num_decode_tokens
],
cu_seqlens_q
=
attn_metadata
.
query_start_loc
[:
num_decodes
],
max_seqlen_q
=
1
,
# optimize this
seqused_k
=
attn_metadata
.
seq_lens
[:
num_decodes
],
max_seqlen_k
=
attn_metadata
.
max_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
block_table
=
attn_metadata
.
block_table
[:
num_decodes
],
softcap
=
self
.
logits_soft_cap
,
q_descale
=
None
,
k_descale
=
layer
.
_k_scale
.
expand
(
descale_shape
),
v_descale
=
layer
.
_v_scale
.
expand
(
descale_shape
),
)
return
assert
attn_metadata
.
decode_metadata
is
not
None
_
,
num_heads
,
head_size
=
query
.
shape
nbytes_per_qo_elem
=
torch
.
finfo
(
query
.
dtype
).
bits
//
8
num_seqs
=
attn_metadata
.
seq_lens
.
shape
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment