Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dc937175
Unverified
Commit
dc937175
authored
Nov 05, 2025
by
Pleaplusone
Committed by
GitHub
Nov 04, 2025
Browse files
[ROCm][Perf] New design on ROCm AITER MHA backend Implementation (#25763)
Signed-off-by:
ganyi
<
ygan@amd.com
>
parent
2f1cc8ce
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
593 additions
and
275 deletions
+593
-275
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+526
-275
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+67
-0
No files found.
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
dc937175
...
@@ -13,221 +13,202 @@ from vllm.attention.backends.abstract import (
...
@@ -13,221 +13,202 @@ from vllm.attention.backends.abstract import (
AttentionType
,
AttentionType
,
MultipleOf
,
MultipleOf
,
)
)
from
vllm.attention.ops.merge_attn_states
import
merge_attn_states
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionCGSupport
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
split_decodes_prefills_and_extends
,
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
_PARTITION_SIZE_ROCM
=
256
_PARTITION_SIZE_ROCM
=
256
_CP_TOKENS_PER_ITER_ROCM
=
32
*
1024
if
current_platform
.
is_rocm
():
if
current_platform
.
is_rocm
():
import
aiter
import
aiter
from
aiter.ops.triton.utils.device_info
import
get_num_sms
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils.torch_utils
import
direct_register_custom_op
def block_size(x, head_dim):
    """Pick the Triton block width used when copying one token's KV data.

    The result is the smaller of:
      * ``65536 // x.element_size()`` - an element-count cap derived from
        the per-element byte size of ``x``, and
      * ``head_dim`` rounded up to the next power of two.
    """
    element_cap = 65536 // x.element_size()
    padded_head_dim = triton.next_power_of_2(head_dim)
    return min(element_cap, padded_head_dim)
def num_programs(head_dim):
    """Return the number of kernel programs to launch.

    The value is ``head_dim`` capped at ``get_num_sms()``.

    NOTE(review): despite the parameter name, the caller visible in this
    file (``cp_mha_gather_cache``) passes ``total_tokens`` here.
    """
    sm_count = get_num_sms()
    return min(head_dim, sm_count)
@
triton
.
jit
@
triton
.
jit
def
_vllm_layout_trans_kernel
(
def
cp_mha_gather_cache_kernel
(
k_buffer_ptr
,
key_cache_ptr
,
# [num_blocks, page_size, num_head, head_size]
v_buffer_ptr
,
value_cache_ptr
,
# [num_blocks, page_size, num_head, head_size]
k_values_ptr
,
key_ptr
,
# [num_tokens, num_heads, head_size]
v_values_ptr
,
value_ptr
,
# [num_tokens, num_heads, head_size]
b_query_lens_loc
,
block_table_ptr
,
# [num_batches, max_block_num]
b_seq_lens_loc
,
cu_seqlens_kv_ptr
,
# [num_batches + 1]
block_table
,
token_to_batch_ptr
,
# [max_cum_tokens]
block_table_stride_0
,
seq_start_ptr
,
# [num_batches]
k_scale
,
k_scale_ptr
,
v_scale
,
v_scale_ptr
,
output_dtype
:
tl
.
constexpr
,
num_heads
,
E_DIM
:
tl
.
constexpr
,
head_size
,
x
,
max_block_num
,
num_tokens
,
DEQUANT
:
tl
.
constexpr
,
PAGE_SIZE
:
tl
.
constexpr
,
CACHE_FORMAT
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
NUM_PRGMS
:
tl
.
constexpr
,
):
):
batch_idx
=
tl
.
program_id
(
0
)
bid
=
tl
.
program_id
(
0
)
block_idx
=
tl
.
program_id
(
1
)
col_offsets
=
tl
.
arange
(
0
,
BLOCK_SIZE
)
if
DEQUANT
:
batch_query_indexes
=
tl
.
load
(
b_query_lens_loc
+
batch_idx
+
tl
.
arange
(
0
,
2
))
k_scale
=
tl
.
load
(
k_scale_ptr
)
batch_query_start
,
batch_query_end
=
tl
.
split
(
batch_query_indexes
)
v_scale
=
tl
.
load
(
v_scale_ptr
)
query_len
=
batch_query_end
-
batch_query_start
for
token_id
in
tl
.
range
(
bid
,
num_tokens
,
NUM_PRGMS
):
if
query_len
<=
1
:
key_ptr_offset
=
key_ptr
+
token_id
*
head_size
*
num_heads
return
value_ptr_offset
=
value_ptr
+
token_id
*
head_size
*
num_heads
batch_idx
=
tl
.
load
(
token_to_batch_ptr
+
token_id
)
batch_token_indexes
=
tl
.
load
(
b_seq_lens_loc
+
batch_idx
+
tl
.
arange
(
0
,
2
))
batch_start
=
tl
.
load
(
seq_start_ptr
+
batch_idx
)
batch_token_start
,
batch_token_end
=
tl
.
split
(
batch_token_indexes
)
token_start
=
tl
.
load
(
cu_seqlens_kv_ptr
+
batch_idx
)
seq_len
=
batch_token_end
-
batch_token_start
batch_offset
=
token_id
-
token_start
+
batch_start
block_offset
=
batch_offset
//
PAGE_SIZE
if
block_idx
*
BLOCK_SIZE
<
seq_len
:
block_id
=
tl
.
load
(
block_mask
=
(
block_table_ptr
+
max_block_num
*
batch_idx
+
block_offset
block_idx
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)[:,
None
]
)
)
<
seq_len
slot_id
=
batch_offset
%
PAGE_SIZE
kv_idx
=
tl
.
load
(
if
CACHE_FORMAT
==
"NHD"
:
block_table
+
batch_idx
*
block_table_stride_0
+
block_idx
# for kv cache layout as
).
to
(
tl
.
int64
)
# K: [num_blocks, page_size, num_head, head_dim]
# V: [num_blocks, page_size, num_head, head_dim]
kv_buffer_off
=
(
key_cache_ptr_offset
=
(
kv_idx
*
BLOCK_SIZE
*
E_DIM
key_cache_ptr
+
tl
.
arange
(
0
,
BLOCK_SIZE
)[:,
None
]
*
E_DIM
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
tl
.
arange
(
0
,
E_DIM
)[
None
,
:]
+
slot_id
*
num_heads
*
head_size
)
value_cache_ptr_offset
=
(
value_cache_ptr
+
block_id
*
num_heads
*
head_size
*
PAGE_SIZE
+
slot_id
*
num_heads
*
head_size
)
)
k_vals
=
tl
.
load
(
k_buffer_ptr
+
kv_buffer_off
,
mask
=
block_mask
,
other
=
0.0
)
if
k_vals
.
dtype
.
is_fp8
():
k_vals
=
(
k_vals
.
to
(
tl
.
float32
)
*
tl
.
load
(
k_scale
)).
to
(
output_dtype
)
else
:
k_vals
=
k_vals
.
to
(
output_dtype
)
v_vals
=
tl
.
load
(
v_buffer_ptr
+
kv_buffer_off
,
mask
=
block_mask
,
other
=
0.0
)
for
i
in
tl
.
range
(
0
,
head_size
*
num_heads
,
BLOCK_SIZE
):
if
v_vals
.
dtype
.
is_fp8
():
mask
=
(
col_offsets
+
i
)
<
head_size
*
num_heads
v_vals
=
(
v_vals
.
to
(
tl
.
float32
)
*
tl
.
load
(
v_scale
)).
to
(
output_dtype
)
k_reg
=
tl
.
load
(
key_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
else
:
v_reg
=
tl
.
load
(
value_cache_ptr_offset
+
col_offsets
+
i
,
mask
=
mask
)
v_vals
=
v_vals
.
to
(
output_dtype
)
if
DEQUANT
:
kv_values_off
=
(
k_dtype
=
k_reg
.
dtype
batch_token_start
*
E_DIM
v_dtype
=
v_reg
.
dtype
+
block_idx
*
BLOCK_SIZE
*
E_DIM
k_reg
=
(
k_reg
.
to
(
tl
.
float32
)
*
k_scale
).
to
(
k_dtype
)
+
tl
.
arange
(
0
,
BLOCK_SIZE
)[:,
None
]
*
E_DIM
v_reg
=
(
v_reg
.
to
(
tl
.
float32
)
*
v_scale
).
to
(
v_dtype
)
+
tl
.
arange
(
0
,
E_DIM
)[
None
,
:]
tl
.
store
(
key_ptr_offset
+
col_offsets
+
i
,
k_reg
,
mask
=
mask
)
)
tl
.
store
(
value_ptr_offset
+
col_offsets
+
i
,
v_reg
,
mask
=
mask
)
tl
.
store
(
k_values_ptr
+
kv_values_off
,
k_vals
,
mask
=
block_mask
)
tl
.
store
(
v_values_ptr
+
kv_values_off
,
v_vals
,
mask
=
block_mask
)
def
cp_mha_gather_cache
(
key_cache
:
torch
.
Tensor
,
def
vllm_layout_trans
(
value_cache
:
torch
.
Tensor
,
b_query_lens_loc
,
key
:
torch
.
Tensor
,
b_seq_lens_loc
,
value
:
torch
.
Tensor
,
block_table
,
block_tables
:
torch
.
Tensor
,
k_cache
,
k_scales
:
torch
.
Tensor
,
v_cache
,
v_scales
:
torch
.
Tensor
,
max_seq_len
,
cu_seqlens_kv
:
torch
.
Tensor
,
k_scale
,
token_to_batch
:
torch
.
Tensor
,
v_scale
,
seq_starts
:
torch
.
Tensor
,
output_dtype
,
dequant
:
bool
,
total_tokens
,
kv_cache_layout
:
str
,
total_tokens
:
int
,
):
):
H_KV
=
v_cache
.
shape
[
2
]
assert
kv_cache_layout
in
[
"v0"
,
"NHD"
,
"HND"
],
(
D
=
v_cache
.
shape
[
3
]
"kv_cache_layout only support v0, NHD, HND"
BLOCK_SIZE
=
v_cache
.
shape
[
1
]
k_values
=
torch
.
empty
(
(
total_tokens
,
H_KV
,
D
),
dtype
=
output_dtype
,
device
=
k_cache
.
device
,
)
)
v_values
=
torch
.
empty
(
head_dim
=
key
.
shape
[
2
]
(
total_tokens
,
H_KV
,
D
),
x
=
0
dtype
=
output_dtype
,
# assert dequant is True, "Currently, we only support "\
device
=
v_cache
.
device
,
# "gather cache with dequant"
# For k cache layout: [num_blocks, num_heads, page_size, head_dim]
assert
kv_cache_layout
==
"NHD"
,
(
"ROCM_AITER_FA_BACKEND Only support NHD kv cache layout for now"
)
)
assert
head_dim
==
key_cache
.
shape
[
3
],
(
"We assume your kv cache layout is [num_blocks, "
"page_size, num_heads, head_dim], but got otherwise"
)
page_size
=
key_cache
.
shape
[
1
]
num_heads
=
key_cache
.
shape
[
2
]
grid
=
(
block_table
.
shape
[
0
],
(
max_seq_len
+
BLOCK_SIZE
-
1
)
//
BLOCK_SIZE
)
NUM_PRGMS
=
num_programs
(
total_tokens
)
BLOCK_SIZE
=
block_size
(
key_cache
,
head_dim
)
if
output_dtype
==
torch
.
float16
:
grid
=
lambda
meta
:
(
NUM_PRGMS
,)
output_dtype
=
tl
.
float16
cp_mha_gather_cache_kernel
[
grid
](
elif
output_dtype
==
torch
.
bfloat16
:
key_cache
,
output_dtype
=
tl
.
bfloat16
value_cache
,
else
:
key
,
raise
ValueError
(
f
"Unsupported output dtype:
{
output_dtype
}
"
)
value
,
block_tables
,
_vllm_layout_trans_kernel
[
grid
](
cu_seqlens_kv
,
k_ca
ch
e
,
token_to_bat
ch
,
v_cache
,
seq_starts
,
k_
v
al
u
es
,
k_
sc
ales
,
v_
v
al
u
es
,
v_
sc
ales
,
b_query_lens_loc
,
num_heads
,
b_seq_lens_loc
,
head_dim
,
block_table
,
x
,
block_table
.
s
trid
e
(
0
),
block_table
s
.
s
iz
e
(
1
),
k_scale
,
total_tokens
,
v_scale
,
DEQUANT
=
dequant
,
output_dtype
=
output_dtyp
e
,
PAGE_SIZE
=
page_siz
e
,
E_DIM
=
H_KV
*
D
,
CACHE_FORMAT
=
kv_cache_layout
,
BLOCK_SIZE
=
BLOCK_SIZE
,
BLOCK_SIZE
=
BLOCK_SIZE
,
NUM_PRGMS
=
NUM_PRGMS
,
)
)
return
k_values
,
v_values
def
flash_attn_varlen_func_impl
(
logger
=
init_logger
(
__name__
)
q
:
torch
.
Tensor
,
k_cache
:
torch
.
Tensor
,
v_cache
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
cu_seqlens_q
:
torch
.
Tensor
,
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
float
,
window_size
:
list
[
int
]
|
None
,
# -1 means infinite context window
alibi_slopes
:
list
[
float
]
|
None
,
block_table
:
torch
.
Tensor
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
total_tokens
:
int
=
0
,
)
->
torch
.
Tensor
:
if
total_tokens
==
0
:
total_tokens
=
int
(
cu_seqlens_k
[
-
1
].
item
())
k
,
v
=
vllm_layout_trans
(
cu_seqlens_q
,
cu_seqlens_k
,
block_table
,
k_cache
,
v_cache
,
max_seqlen_k
,
k_scale
,
v_scale
,
q
.
dtype
,
total_tokens
,
)
output
=
aiter
.
flash_attn_varlen_func
(
q
=
q
,
k
=
k
,
v
=
v
,
cu_seqlens_q
=
cu_seqlens_q
,
max_seqlen_q
=
max_seqlen_q
,
min_seqlen_q
=
1
,
cu_seqlens_k
=
cu_seqlens_k
,
max_seqlen_k
=
max_seqlen_k
,
softmax_scale
=
softmax_scale
,
causal
=
True
,
alibi_slopes
=
alibi_slopes
,
window_size
=
window_size
,
out
=
out
,
)
return
output
def
flash_attn_varlen_func_fake
(
@
dataclass
q
:
torch
.
Tensor
,
class
AiterFlashAttentionDecodeMetadata
:
k_cache
:
torch
.
Tensor
,
max_query_len
:
int
v_cache
:
torch
.
Tensor
,
min_query_len
:
int
out
:
torch
.
Tensor
,
max_seq_len
:
int
cu_seqlens_q
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
cu_seqlens_k
:
torch
.
Tensor
,
max_seqlen_q
:
int
,
max_seqlen_k
:
int
,
softmax_scale
:
float
,
window_size
:
list
[
int
]
|
None
,
# -1 means infinite context window
alibi_slopes
:
list
[
float
]
|
None
,
block_table
:
torch
.
Tensor
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
total_tokens
:
int
=
0
,
)
->
torch
.
Tensor
:
return
torch
.
empty
(
q
.
shape
[
0
],
q
.
shape
[
1
],
v_cache
.
shape
[
-
2
],
dtype
=
q
.
dtype
,
device
=
q
.
device
)
direct_register_custom_op
(
"flash_attn_varlen_func"
,
flash_attn_varlen_func_impl
,
[
"out"
],
flash_attn_varlen_func_fake
,
dispatch_key
=
current_platform
.
dispatch_key
,
)
logger
=
init_logger
(
__name__
)
@dataclass
class AiterFlashAttentionPrefillMetadata:
    """Per-batch metadata for the pure-prefill requests of a batch.

    The metadata builder fills this from the trailing slice of the batch
    (requests after the decode and extend requests).
    """

    # Longest query length among the prefill requests.
    max_query_len: int
    # Shortest query length among the prefill requests.
    min_query_len: int
    # Longest sequence length among the prefill requests.
    max_seq_len: int
    # Cumulative query start offsets for the prefill requests; the builder
    # subtracts the first entry so the offsets start at zero.
    query_start_loc: torch.Tensor
@dataclass
class AiterChunkContextMetadata:
    """Chunking plan for reading already-cached context of extend requests.

    The metadata builder splits the cached context of every extend request
    into ``num_chunks`` chunks; most fields below hold one row / entry per
    chunk.
    """

    # Scratch buffer the gathered K (index 0) and V (index 1) are written to.
    workspace: torch.Tensor
    # Per-chunk cumulative context lengths, built as
    # [num_chunks, num_extends + 1] with a leading zero column.
    cu_seq_lens_chunk: torch.Tensor
    # Per-chunk start offset inside each request's context,
    # built as [num_chunks, num_extends].
    chunk_starts: torch.Tensor
    # Per-chunk map from gathered-token index to request index,
    # built as [num_chunks, max_cum_tokens].
    token_to_batch: torch.Tensor
    # Per-chunk sum of the chunk's context lengths over all requests.
    seq_tot: list[int]
    # Per-chunk maximum context length over all requests.
    max_seq_lens: list[int]
    # Per-chunk, per-request context lengths (clamped at zero).
    seq_lens: torch.Tensor
    # Number of chunks the context is split into.
    num_chunks: int
    # Per-chunk total number of gathered tokens (last cumulative entry).
    total_token_per_batch: list[int]
@dataclass
class AiterFlashAttentionChunkPrefillMetadata:
    """Per-batch metadata for the extend (chunked-prefill) requests.

    These are the requests placed between the decode requests and the pure
    prefill requests in the batch.
    """

    # Longest query length among the extend requests.
    max_query_len: int
    # Shortest query length among the extend requests.
    min_query_len: int
    # Longest sequence length among the extend requests.
    max_seq_len: int
    # Cumulative query start offsets for the extend requests; the builder
    # subtracts the first entry so the offsets start at zero.
    query_start_loc: torch.Tensor
    # Chunking plan used to read the already-cached context of the requests.
    chunk_context_metadata: AiterChunkContextMetadata
@
dataclass
@
dataclass
...
@@ -248,7 +229,18 @@ class AiterFlashAttentionMetadata:
...
@@ -248,7 +229,18 @@ class AiterFlashAttentionMetadata:
seq_lens
:
torch
.
Tensor
seq_lens
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
block_table
:
torch
.
Tensor
block_table
:
torch
.
Tensor
cu_seq_lens
:
torch
.
Tensor
|
None
# prefill and decode split
num_decodes
:
int
num_decode_tokens
:
int
num_prefills
:
int
num_prefill_tokens
:
int
num_extends
:
int
num_extend_tokens
:
int
decode_metadata
:
AiterFlashAttentionDecodeMetadata
|
None
prefill_metadata
:
AiterFlashAttentionPrefillMetadata
|
None
extend_metadata
:
AiterFlashAttentionChunkPrefillMetadata
|
None
# For cascade attention.
# For cascade attention.
use_cascade
:
bool
use_cascade
:
bool
...
@@ -260,6 +252,7 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -260,6 +252,7 @@ class AiterFlashAttentionMetadataBuilder(
AttentionMetadataBuilder
[
AiterFlashAttentionMetadata
]
AttentionMetadataBuilder
[
AiterFlashAttentionMetadata
]
):
):
cudagraph_support
=
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
cudagraph_support
=
AttentionCGSupport
.
UNIFORM_SINGLE_TOKEN_DECODE
reorder_batch_threshold
:
int
=
1
def
__init__
(
def
__init__
(
self
,
self
,
...
@@ -285,6 +278,12 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -285,6 +278,12 @@ class AiterFlashAttentionMetadataBuilder(
self
.
aot_sliding_window
:
tuple
[
int
,
int
]
|
None
=
None
self
.
aot_sliding_window
:
tuple
[
int
,
int
]
|
None
=
None
self
.
total_tokens
:
int
=
0
self
.
total_tokens
:
int
=
0
self
.
extend_workspace
=
torch
.
empty
(
[
2
,
_CP_TOKENS_PER_ITER_ROCM
,
self
.
num_heads_kv
,
self
.
headdim
],
dtype
=
self
.
model_config
.
dtype
,
device
=
device
,
)
def
build_for_cudagraph_capture
(
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
self
,
common_attn_metadata
:
CommonAttentionMetadata
):
):
...
@@ -302,42 +301,139 @@ class AiterFlashAttentionMetadataBuilder(
...
@@ -302,42 +301,139 @@ class AiterFlashAttentionMetadataBuilder(
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
fast_build
:
bool
=
False
,
fast_build
:
bool
=
False
,
)
->
"AiterFlashAttentionMetadata"
:
)
->
"AiterFlashAttentionMetadata"
:
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
split_ret
=
split_decodes_prefills_and_extends
(
max_query_len
=
common_attn_metadata
.
max_query_len
common_attn_metadata
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
decode_threshold
=
self
.
reorder_batch_threshold
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
)
seq_lens
=
common_attn_metadata
.
seq_lens
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
(
slot_mapping
=
common_attn_metadata
.
slot_mapping
num_decodes
,
if
max_query_len
>
1
:
num_extends
,
# We pre-compute cumulative seq len needed for prefill attention
num_prefills
,
# here to avoid recomputing it for every layer
num_decode_tokens
,
num_extend_tokens
,
num_prefill_tokens
,
)
=
split_ret
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens_cpu
query_lens_cpu
=
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
]
decode_metadata
=
None
if
num_decodes
>
0
:
decode_metadata
=
AiterFlashAttentionDecodeMetadata
(
max_query_len
=
query_lens_cpu
[:
num_decodes
].
max
().
item
(),
min_query_len
=
query_lens_cpu
[:
num_decodes
].
min
().
item
(),
max_seq_len
=
seq_lens
[:
num_decodes
].
max
().
item
(),
query_start_loc
=
common_attn_metadata
.
query_start_loc
[:
num_decodes
+
1
],
)
prefill_metadata
=
None
if
num_prefills
>
0
:
query_lens_for_prefill
=
query_lens_cpu
[
num_decodes
+
num_extends
:]
query_start_loc_device
=
common_attn_metadata
.
query_start_loc
[
num_decodes
+
num_extends
:
]
prefill_metadata
=
AiterFlashAttentionPrefillMetadata
(
max_query_len
=
query_lens_for_prefill
.
max
().
item
(),
min_query_len
=
query_lens_for_prefill
.
min
().
item
(),
max_seq_len
=
seq_lens
[
num_decodes
+
num_extends
:].
max
().
item
(),
query_start_loc
=
query_start_loc_device
-
query_start_loc_device
[
0
],
)
extend_metadata
=
None
if
num_extends
>
0
:
num_extends_slice
=
slice
(
num_decodes
,
num_decodes
+
num_extends
)
query_lens_for_extend
=
query_lens_cpu
[
num_extends_slice
]
seq_lens_for_extend
=
common_attn_metadata
.
seq_lens_cpu
[
num_extends_slice
]
computed_kv_lens
=
seq_lens_for_extend
-
query_lens_for_extend
# allocate the equal amount of workspace for
# each chunk prefill request
max_context_chunk
=
_CP_TOKENS_PER_ITER_ROCM
//
num_extends
num_chunks
=
cdiv
(
computed_kv_lens
.
max
().
item
(),
max_context_chunk
)
chunk_starts
=
(
torch
.
arange
(
num_chunks
,
dtype
=
torch
.
int32
)
.
unsqueeze
(
1
)
.
expand
(
-
1
,
num_extends
)
*
max_context_chunk
)
chunk_ends
=
torch
.
min
(
computed_kv_lens
.
unsqueeze
(
0
),
chunk_starts
+
max_context_chunk
)
chunk_seq_lens
=
(
chunk_ends
-
chunk_starts
).
clamp
(
min
=
0
)
# [num_chunks, num_extends]
cu_seq_lens_cpu
=
torch
.
zeros
(
[
num_chunks
,
num_extends
+
1
],
dtype
=
torch
.
int32
,
pin_memory
=
True
)
torch
.
cumsum
(
chunk_seq_lens
,
dim
=
1
,
out
=
cu_seq_lens_cpu
[:,
1
:],
dtype
=
torch
.
int32
)
max_cum_tokens
=
cu_seq_lens_cpu
[:,
-
1
].
max
().
item
()
range_idx
=
torch
.
arange
(
max_cum_tokens
,
dtype
=
torch
.
int32
)[
None
,
None
,
:]
idx_to_batch_tensor
=
range_idx
==
cu_seq_lens_cpu
[:,
1
:][:,
:,
None
]
idx_to_batch_tensor
=
idx_to_batch_tensor
.
sum
(
dim
=
1
)
# [num_chunks, max_cum_tokens]
token_to_batch_tensor
=
torch
.
cumsum
(
idx_to_batch_tensor
,
dim
=
1
)
chunk_context_metadata
=
AiterChunkContextMetadata
(
workspace
=
self
.
extend_workspace
,
cu_seq_lens_chunk
=
cu_seq_lens_cpu
.
to
(
self
.
device
,
non_blocking
=
True
),
chunk_starts
=
chunk_starts
.
to
(
self
.
device
,
non_blocking
=
True
),
seq_tot
=
chunk_seq_lens
.
sum
(
dim
=
1
).
tolist
(),
max_seq_lens
=
chunk_seq_lens
.
max
(
dim
=
1
).
values
.
tolist
(),
seq_lens
=
chunk_seq_lens
,
token_to_batch
=
token_to_batch_tensor
.
to
(
self
.
device
,
non_blocking
=
True
),
num_chunks
=
num_chunks
,
total_token_per_batch
=
cu_seq_lens_cpu
[:,
-
1
].
tolist
(),
)
query_start_loc_device
=
common_attn_metadata
.
query_start_loc
[
num_decodes
:
num_decodes
+
num_extends
+
1
]
seq_lens_device
=
common_attn_metadata
.
seq_lens
[
num_extends_slice
]
cu_seq_lens
=
torch
.
zeros
(
cu_seq_lens
=
torch
.
zeros
(
seq_lens
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
seq_lens
.
device
num_extends
+
1
,
dtype
=
torch
.
int32
,
device
=
seq_lens_device
.
device
)
torch
.
cumsum
(
seq_lens_device
,
dim
=
0
,
dtype
=
cu_seq_lens
.
dtype
,
out
=
cu_seq_lens
[
1
:]
)
extend_metadata
=
AiterFlashAttentionChunkPrefillMetadata
(
max_query_len
=
query_lens_for_extend
.
max
().
item
(),
min_query_len
=
query_lens_for_extend
.
min
().
item
(),
max_seq_len
=
seq_lens
[
num_extends_slice
].
max
().
item
(),
query_start_loc
=
query_start_loc_device
-
query_start_loc_device
[
0
],
chunk_context_metadata
=
chunk_context_metadata
,
)
)
torch
.
cumsum
(
seq_lens
,
dim
=
0
,
dtype
=
cu_seq_lens
.
dtype
,
out
=
cu_seq_lens
[
1
:])
num_actual_kv_tokens
=
int
(
cu_seq_lens
[
-
1
].
item
())
else
:
cu_seq_lens
=
None
num_actual_kv_tokens
=
0
def
schedule
(
num_actual_kv_tokens
=
torch
.
sum
(
seq_lens
).
item
()
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
return
None
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
attn_metadata
=
AiterFlashAttentionMetadata
(
attn_metadata
=
AiterFlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
num_actual_kv_tokens
=
num_actual_kv_tokens
,
num_actual_kv_tokens
=
num_actual_kv_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
common_attn_metadata
.
max_query_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
common_attn_metadata
.
query_start_loc
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
common_attn_metadata
.
max_seq_len
,
seq_lens
=
seq_lens
,
seq_lens
=
common_attn_metadata
.
seq_lens
,
block_table
=
block_table_tensor
,
block_table
=
common_attn_metadata
.
block_table_tensor
,
slot_mapping
=
slot_mapping
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
cu_seq_lens
=
cu_seq_lens
,
num_decodes
=
num_decodes
,
num_decode_tokens
=
num_decode_tokens
,
num_prefills
=
num_prefills
,
num_prefill_tokens
=
num_prefill_tokens
,
num_extends
=
num_extends
,
num_extend_tokens
=
num_extend_tokens
,
decode_metadata
=
decode_metadata
,
prefill_metadata
=
prefill_metadata
,
extend_metadata
=
extend_metadata
,
use_cascade
=
use_cascade
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
common_prefix_len
=
common_prefix_len
,
total_tokens
=
self
.
total_tokens
,
total_tokens
=
self
.
total_tokens
,
...
@@ -401,6 +497,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
...
@@ -401,6 +497,7 @@ class AiterFlashAttentionBackend(AttentionBackend):
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
if
block_size
%
16
!=
0
:
if
block_size
%
16
!=
0
:
raise
ValueError
(
"Block size must be a multiple of 16."
)
raise
ValueError
(
"Block size must be a multiple of 16."
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
return
(
2
,
num_blocks
,
block_size
,
num_kv_heads
,
head_size
)
...
@@ -449,6 +546,110 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -449,6 +546,110 @@ class AiterFlashAttentionImpl(AttentionImpl):
"FlashAttentionImpl"
"FlashAttentionImpl"
)
)
def extend_forward(
    self,
    attn_metadata: AiterFlashAttentionMetadata,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    output: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    min_seqlen_q: int,
    block_table: torch.Tensor,
    slot_mapping: torch.Tensor,
    k_scale: float,
    v_scale: float,
):
    """Run attention for extend requests (new tokens plus cached context).

    The work is done in three steps:
      1. Causal attention of the new tokens against themselves
         (``key`` / ``value``), keeping the log-sum-exp.
      2. For every context chunk described by
         ``attn_metadata.extend_metadata.chunk_context_metadata``: gather
         that chunk from the paged KV cache into the shared workspace and
         run non-causal attention of the new tokens against it, folding the
         per-chunk results together with ``merge_attn_states``.
      3. Merge the accumulated context result with the result of step 1 and
         write it into ``output``.

    Args:
        attn_metadata: Batch metadata; ``extend_metadata`` must be set.
        query: Query tokens of the extend requests.
        key: Newly computed keys of the extend requests.
        value: Newly computed values of the extend requests.
        key_cache: Paged key cache the context is gathered from.
        value_cache: Paged value cache the context is gathered from.
        output: Destination tensor for the final merged attention result.
        cu_seqlens_q: Cumulative query lengths; also used as the key
            lengths in step 1.
        max_seqlen_q: Maximum query length; also used as the maximum key
            length in step 1.
        max_seqlen_k: Not read by this method.
        min_seqlen_q: Minimum query length, forwarded to the attention op.
        block_table: Block table rows of the extend requests.
        slot_mapping: Not read by this method.
        k_scale: Forwarded to ``cp_mha_gather_cache`` as ``k_scales``.
        v_scale: Forwarded to ``cp_mha_gather_cache`` as ``v_scales``.

    Returns:
        None. The result is written into ``output``.
    """
    # NOTE(review): k_scale / v_scale are annotated ``float`` here while
    # cp_mha_gather_cache annotates them as tensors - confirm what the
    # caller actually passes.

    # Step 1: attention over the new tokens only. K/V lengths equal the
    # query lengths, hence cu_seqlens_q / max_seqlen_q are reused for K.
    out, lse = aiter.flash_attn_varlen_func(
        q=query,
        k=key,
        v=value,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_q,
        min_seqlen_q=min_seqlen_q,
        dropout_p=0.0,
        softmax_scale=self.scale,
        causal=True,
        window_size=self.sliding_window,
        alibi_slopes=self.alibi_slopes,
        return_lse=True,
    )
    assert attn_metadata.extend_metadata is not None
    chunk_context_metadata = attn_metadata.extend_metadata.chunk_context_metadata
    num_chunks = chunk_context_metadata.num_chunks
    workspace = chunk_context_metadata.workspace
    cu_seqlens_kv = chunk_context_metadata.cu_seq_lens_chunk
    max_seqlens = chunk_context_metadata.max_seq_lens
    chunk_starts = chunk_context_metadata.chunk_starts
    token_to_batch = chunk_context_metadata.token_to_batch
    total_token_per_batch = chunk_context_metadata.total_token_per_batch
    # workspace[0] receives gathered keys, workspace[1] gathered values; the
    # same buffers are reused for every chunk.
    key_fetched, value_fetched = workspace[0], workspace[1]
    # Running merge of the per-chunk context results (None until the first
    # chunk has been processed).
    chunked_output = None
    chunked_lse = None
    # Step 2: walk over the context chunks.
    for chunk_idx in range(num_chunks):
        # Copy this chunk's context K/V out of the paged cache into the
        # workspace buffers.
        # NOTE(review): dequant is hard-coded to False - confirm this is
        # intended when the KV cache is stored in a quantized dtype.
        cp_mha_gather_cache(
            key_cache=key_cache,
            value_cache=value_cache,
            key=key_fetched,
            value=value_fetched,
            block_tables=block_table,
            k_scales=k_scale,
            v_scales=v_scale,
            cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
            token_to_batch=token_to_batch[chunk_idx],
            seq_starts=chunk_starts[chunk_idx],
            dequant=False,
            kv_cache_layout="NHD",
            total_tokens=total_token_per_batch[chunk_idx],
        )

        # Attention of the new tokens against the gathered context chunk;
        # non-causal because the whole chunk is passed as K/V with
        # causal=False.
        suf_out, suf_lse = aiter.flash_attn_varlen_func(
            q=query,
            k=key_fetched,
            v=value_fetched,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_kv[chunk_idx],
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlens[chunk_idx],
            min_seqlen_q=min_seqlen_q,
            dropout_p=0.0,
            softmax_scale=self.scale,
            causal=False,
            window_size=self.sliding_window,
            alibi_slopes=self.alibi_slopes,
            return_lse=True,
        )
        if chunked_output is None:
            # First chunk: nothing to merge with yet.
            chunked_output = suf_out
            chunked_lse = suf_lse
        else:
            # Merge into fresh buffers so the inputs of the merge are not
            # overwritten while being read.
            tmp_output = torch.empty_like(out)
            tmp_lse = torch.empty_like(lse)
            merge_attn_states(
                output=tmp_output,
                output_lse=tmp_lse,
                prefix_output=chunked_output,
                prefix_lse=chunked_lse,
                suffix_output=suf_out,
                suffix_lse=suf_lse,
            )
            chunked_output = tmp_output
            chunked_lse = tmp_lse

    # Step 3: combine the context result (prefix) with the new-token result
    # (suffix) and write the final values into the caller's output tensor.
    # NOTE(review): if num_chunks were 0, chunked_output would still be None
    # here - presumably the builder guarantees at least one chunk; verify.
    merge_attn_states(
        output=output,
        prefix_output=chunked_output,
        prefix_lse=chunked_lse,
        suffix_output=out,
        suffix_lse=lse,
    )
def
forward
(
def
forward
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
...
@@ -488,24 +689,25 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -488,24 +689,25 @@ class AiterFlashAttentionImpl(AttentionImpl):
return
output
.
fill_
(
0
)
return
output
.
fill_
(
0
)
# IMPORTANT!
# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# NOTE(woosuk): With piece-wise CUDA graphs, this method is
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
# executed in eager-mode PyTorch. Thus, we need to be careful
# in this method. For example, `view` and `slice` (or `[:n]`) operations
# about any CPU overhead in this method. For example, `view`
# are surprisingly slow even in the case they do not invoke any GPU ops.
# and `slice` (or `[:n]`) operations are surprisingly slow even
# in the case they do not invoke any GPU ops.
# Minimize the PyTorch ops in this method as much as possible.
# Minimize the PyTorch ops in this method as much as possible.
# Whenever making a change in this method, please benchmark the
# Whenever making a change in this method, please benchmark the
# performance to make sure it does not introduce any overhead.
# performance to make sure it does not introduce any overhead.
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
if
self
.
kv_sharing_target_layer_name
is
None
:
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# not padded. However, we don't need to do key[:num_actual_tokens]
# is not padded. However, we don't need to do
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# op uses the slot_mapping's shape to determine the number of
# the reshape_and_cache_flash op uses the slot_mapping's shape
# actual tokens.
# to determine the number of actual tokens.
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
(
key
,
key
,
value
,
value
,
...
@@ -521,37 +723,85 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -521,37 +723,85 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
# decode:extend:prefill
query
=
query
[:
num_actual_tokens
]
key
=
key
[:
num_actual_tokens
]
value
=
value
[:
num_actual_tokens
]
output_actual_tokens
=
output
[:
num_actual_tokens
]
num_decodes
=
attn_metadata
.
num_decodes
num_prefills
=
attn_metadata
.
num_prefills
num_extends
=
attn_metadata
.
num_extends
num_decode_tokens
=
attn_metadata
.
num_decode_tokens
num_extend_tokens
=
attn_metadata
.
num_extend_tokens
if
not
attn_metadata
.
use_cascade
:
if
not
attn_metadata
.
use_cascade
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
# calculate for pure prefills
seqused_k
=
attn_metadata
.
seq_lens
if
num_prefills
>
0
:
max_seqlen_q
=
attn_metadata
.
max_query_len
assert
attn_metadata
.
prefill_metadata
is
not
None
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
prefill_query
=
query
[
num_decode_tokens
+
num_extend_tokens
:]
prefill_key
=
key
[
num_decode_tokens
+
num_extend_tokens
:]
if
max_seqlen_q
>
1
:
prefill_value
=
value
[
num_decode_tokens
+
num_extend_tokens
:]
torch
.
ops
.
vllm
.
flash_attn_varlen_func
(
query
[:
num_actual_tokens
],
aiter
.
flash_attn_varlen_func
(
key_cache
,
q
=
prefill_query
,
value_cache
,
k
=
prefill_key
,
out
=
output
[:
num_actual_tokens
],
v
=
prefill_value
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_q
=
attn_metadata
.
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
max_seqlen_q
,
cu_seqlens_k
=
attn_metadata
.
prefill_metadata
.
query_start_loc
,
max_seqlen_k
=
max_seqlen_k
,
max_seqlen_q
=
attn_metadata
.
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
prefill_metadata
.
max_seq_len
,
min_seqlen_q
=
attn_metadata
.
prefill_metadata
.
min_query_len
,
dropout_p
=
0.0
,
softmax_scale
=
self
.
scale
,
softmax_scale
=
self
.
scale
,
alibi_slopes
=
self
.
alibi_slopes
,
causal
=
True
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
alibi_slopes
=
self
.
alibi_slopes
,
cu_seqlens_k
=
attn_metadata
.
cu_seq_lens
,
out
=
output_actual_tokens
[
num_decode_tokens
+
num_extend_tokens
:],
)
# calculate for extends
if
num_extends
>
0
:
assert
attn_metadata
.
extend_metadata
is
not
None
extend_tokens_slice
=
slice
(
num_decode_tokens
,
num_decode_tokens
+
num_extend_tokens
)
extend_querys
=
query
[
extend_tokens_slice
]
extend_keys
=
key
[
extend_tokens_slice
]
extend_values
=
value
[
extend_tokens_slice
]
extend_outputs
=
output
[
extend_tokens_slice
]
self
.
extend_forward
(
attn_metadata
=
attn_metadata
,
query
=
extend_querys
,
key
=
extend_keys
,
value
=
extend_values
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
output
=
extend_outputs
,
cu_seqlens_q
=
attn_metadata
.
extend_metadata
.
query_start_loc
,
max_seqlen_q
=
attn_metadata
.
extend_metadata
.
max_query_len
,
max_seqlen_k
=
attn_metadata
.
extend_metadata
.
max_seq_len
,
min_seqlen_q
=
attn_metadata
.
extend_metadata
.
min_query_len
,
block_table
=
attn_metadata
.
block_table
[
num_decodes
:
num_decodes
+
num_extends
],
slot_mapping
=
attn_metadata
.
slot_mapping
[
num_decodes
:
num_decodes
+
num_extends
],
k_scale
=
layer
.
_k_scale
,
k_scale
=
layer
.
_k_scale
,
v_scale
=
layer
.
_v_scale
,
v_scale
=
layer
.
_v_scale
,
total_tokens
=
attn_metadata
.
num_actual_kv_tokens
,
)
)
# calculate for decodes
if
num_decodes
>
0
:
assert
attn_metadata
.
decode_metadata
is
not
None
_
,
num_heads
,
head_size
=
query
.
shape
_
,
num_heads
,
head_size
=
query
.
shape
nbytes_per_qo_elem
=
torch
.
finfo
(
query
.
dtype
).
bits
//
8
nbytes_per_qo_elem
=
torch
.
finfo
(
query
.
dtype
).
bits
//
8
num_seqs
=
seqused_k
.
shape
[
0
]
num_seqs
=
attn_metadata
.
seq_lens
.
shape
[
0
]
max_num_partitions
=
(
max_num_partitions
=
(
max_seqlen
_k
+
_PARTITION_SIZE_ROCM
-
1
attn_metadata
.
max_seq
_
len
+
_PARTITION_SIZE_ROCM
-
1
)
//
_PARTITION_SIZE_ROCM
)
//
_PARTITION_SIZE_ROCM
workspace_buffer
=
torch
.
empty
(
workspace_buffer
=
torch
.
empty
(
...
@@ -563,16 +813,16 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -563,16 +813,16 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
)
torch
.
ops
.
aiter
.
paged_attention_v1
(
torch
.
ops
.
aiter
.
paged_attention_v1
(
output
[:
num_
actual
_tokens
],
output
[:
num_
decode
_tokens
],
workspace_buffer
,
workspace_buffer
,
query
[:
num_
actual
_tokens
],
query
[:
num_
decode
_tokens
],
key_cache
,
key_cache
,
value_cache
,
value_cache
,
self
.
scale
,
self
.
scale
,
block_table
,
attn_metadata
.
block_table
[:
num_decodes
]
,
cu_seqlens_q
,
attn_metadata
.
query_start_loc
[:
num_decodes
]
,
seqused_k
,
attn_metadata
.
seq_lens
[:
num_decodes
]
,
max_seqlen
_k
,
attn_metadata
.
max_seq
_
len
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
"NHD"
,
"NHD"
,
...
@@ -582,8 +832,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -582,8 +832,9 @@ class AiterFlashAttentionImpl(AttentionImpl):
None
,
None
,
_PARTITION_SIZE_ROCM
,
_PARTITION_SIZE_ROCM
,
)
)
return
output
else
:
else
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Cascade attention is not implemented for ROCM AITER"
"Cascade attention is not implemented for ROCM AITER"
)
)
return
output
vllm/v1/attention/backends/utils.py
View file @
dc937175
...
@@ -728,6 +728,73 @@ def subclass_attention_backend(
...
@@ -728,6 +728,73 @@ def subclass_attention_backend(
)
)
def split_decodes_prefills_and_extends(
    common_attn_metadata: "CommonAttentionMetadata",
    decode_threshold: int = 1,
) -> tuple[int, int, int, int, int, int]:
    """
    Assuming a reordered batch (decodes first, then extends, then prefills),
    finds the boundaries between decode, extend and prefill requests.

    A request is a decode when its query length is at most
    ``decode_threshold``. Among the remaining requests, a request whose
    sequence length equals its query length (nothing cached yet) is a
    prefill; the others are extends.

    Args:
        common_attn_metadata: CommonAttentionMetadata object containing the
            batch metadata.
        decode_threshold: The maximum query length to be considered a decode.

    Returns:
        num_decodes: The number of decode requests.
        num_extends: The number of extend requests.
        num_prefills: The number of prefill requests.
        num_decode_tokens: The number of tokens in the decode requests.
        num_extend_tokens: The number of tokens in the extend requests.
        num_prefill_tokens: The number of tokens in the prefill requests.

        All six values are plain Python ints.
    """
    max_query_len = common_attn_metadata.max_query_len
    num_reqs = common_attn_metadata.num_reqs
    num_tokens = common_attn_metadata.num_actual_tokens
    query_start_loc = common_attn_metadata.query_start_loc_cpu
    seq_lens = common_attn_metadata.seq_lens_cpu

    # Fast path: no request can exceed the decode threshold.
    if max_query_len <= decode_threshold:
        return num_reqs, 0, 0, num_tokens, 0, 0

    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    is_prefill_or_extend = query_lens > decode_threshold

    # This must be checked BEFORE argmax: argmax over an all-False mask
    # returns index 0, which would wrongly report zero decode requests
    # (e.g. when max_query_len is over-reported).
    if not torch.any(is_prefill_or_extend):
        return num_reqs, 0, 0, num_tokens, 0, 0

    # Index of the first non-decode request == number of decode requests.
    first_extend = int(is_prefill_or_extend.int().argmax(dim=-1).item())
    num_decodes = first_extend
    num_decode_tokens = int(query_start_loc[first_extend].item())

    num_prefills_or_extends = num_reqs - num_decodes
    num_prefill_or_extend_tokens = num_tokens - num_decode_tokens

    is_prefill = (seq_lens == query_lens) & is_prefill_or_extend
    if not torch.any(is_prefill):
        # Every non-decode request has cached context: all are extends.
        return (
            num_decodes,
            num_prefills_or_extends,
            0,
            num_decode_tokens,
            num_prefill_or_extend_tokens,
            0,
        )

    first_prefill = int(is_prefill.int().argmax(dim=-1).item())
    num_extends = first_prefill - num_decodes
    num_prefills = num_reqs - first_prefill

    # .item() keeps the token counts as Python ints instead of 0-dim
    # tensors, matching the declared return type.
    num_prefill_tokens = num_tokens - int(query_start_loc[first_prefill].item())
    num_extend_tokens = num_prefill_or_extend_tokens - num_prefill_tokens
    return (
        num_decodes,
        num_extends,
        num_prefills,
        num_decode_tokens,
        num_extend_tokens,
        num_prefill_tokens,
    )
def
split_decodes_and_prefills
(
def
split_decodes_and_prefills
(
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
decode_threshold
:
int
=
1
,
decode_threshold
:
int
=
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment