Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a44b1c95
Unverified
Commit
a44b1c95
authored
Jun 17, 2025
by
Charlie Fu
Committed by
GitHub
Jun 17, 2025
Browse files
[Feature][ROCm] Add full graph capture support for TritonAttentionBackend (#19158)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
b447624e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
334 additions
and
178 deletions
+334
-178
tests/compile/piecewise/test_full_cudagraph.py
tests/compile/piecewise/test_full_cudagraph.py
+1
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+3
-2
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+3
-169
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+159
-7
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+168
-0
No files found.
tests/compile/piecewise/test_full_cudagraph.py
View file @
a44b1c95
...
@@ -147,6 +147,7 @@ def test_lower_max_num_seqs(model, supported):
...
@@ -147,6 +147,7 @@ def test_lower_max_num_seqs(model, supported):
llm
.
generate
([
"Hello, my name is"
]
*
10
)
llm
.
generate
([
"Hello, my name is"
]
*
10
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_cuda
(),
reason
=
"Skip if not cuda"
)
def
test_full_cudagraph_with_invalid_backend
():
def
test_full_cudagraph_with_invalid_backend
():
with
temporary_environ
({
with
temporary_environ
({
"VLLM_USE_V1"
:
"1"
,
"VLLM_USE_V1"
:
"1"
,
...
...
vllm/platforms/rocm.py
View file @
a44b1c95
...
@@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention(
...
@@ -141,7 +141,8 @@ def use_rocm_custom_paged_attention(
and
(
head_size
==
64
or
head_size
==
128
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
max_seq_len
<=
128
*
1024
and
(
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
not
(
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
and
not
(
envs
.
VLLM_ROCM_USE_AITER_PAGED_ATTN
and
envs
.
VLLM_ROCM_USE_AITER
))
and
envs
.
VLLM_ROCM_USE_AITER
))
...
@@ -151,7 +152,7 @@ def use_rocm_custom_paged_attention(
...
@@ -151,7 +152,7 @@ def use_rocm_custom_paged_attention(
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
head_size
==
128
and
block_size
==
16
and
head_size
==
128
and
block_size
==
16
and
(
gqa_ratio
>=
3
and
gqa_ratio
<=
16
)
and
(
gqa_ratio
>=
3
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
and
alibi_slopes
is
None
and
max_seq_len
<=
128
*
1024
and
alibi_slopes
is
None
and
kv_cache_dtype
==
"auto"
and
kv_cache_dtype
==
"auto"
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
and
envs
.
VLLM_ROCM_CUSTOM_PAGED_ATTN
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
a44b1c95
...
@@ -19,9 +19,9 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
...
@@ -19,9 +19,9 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
get_kv_cache_layout
)
make_local_attention_virtual_batches
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -126,172 +126,6 @@ class FlashAttentionMetadata:
...
@@ -126,172 +126,6 @@ class FlashAttentionMetadata:
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
Tensor
,
block_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
Tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block
=
np
.
minimum
(
attn_chunk_size
-
((
seq_lens_np
-
q_seqlens
)
%
attn_chunk_size
),
q_seqlens
).
astype
(
np
.
int32
)
tokens_in_last_block
=
attn_chunk_size
+
(
seq_lens_np
%
-
attn_chunk_size
)
local_blocks
=
1
+
cdiv
(
q_seqlens
-
q_tokens_in_first_block
,
attn_chunk_size
)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks
=
np
.
cumsum
(
local_blocks
)
virtual_batches
=
cu_num_blocks
[
-
1
]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
local_blocks
,
local_blocks
)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange
=
np
.
repeat
(
local_blocks
,
local_blocks
)
-
arange
-
1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local
=
\
np
.
repeat
(
q_seqlens
-
q_tokens_in_first_block
,
local_blocks
)
# set the first block since this may be a partial block
seqlens_q_local
[
arange
==
0
]
=
q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local
[
arange
>
0
]
=
np
.
minimum
(
seqlens_q_local
-
attn_chunk_size
*
(
arange
-
1
),
attn_chunk_size
)[
arange
>
0
]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local
=
np
.
pad
(
np
.
cumsum
(
seqlens_q_local
),
(
1
,
0
))
\
.
astype
(
np
.
int32
)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local
=
np
.
full
(
cu_num_blocks
[
-
1
],
attn_chunk_size
,
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
np
.
repeat
(
tokens_in_last_block
,
local_blocks
))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
block_size
assert
attn_chunk_size
%
block_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by block_size
{
block_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
block_size
# Create a block_table for the local attention blocks
# For our example if we have a block-table like (assuming block_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices
=
np
.
broadcast_to
(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
().
clip
(
max
=
block_table
.
shape
[
1
]
-
1
)
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
block_table_local
def
_get_sliding_window_configs
(
def
_get_sliding_window_configs
(
vllm_config
:
VllmConfig
)
->
set
[
Optional
[
tuple
[
int
,
int
]]]:
vllm_config
:
VllmConfig
)
->
set
[
Optional
[
tuple
[
int
,
int
]]]:
"""Get the set of all sliding window configs used in the model."""
"""Get the set of all sliding window configs used in the model."""
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
a44b1c95
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
"""Attention layer with PagedAttention and Triton prefix prefill."""
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
import
torch
import
torch
...
@@ -15,8 +16,10 @@ from vllm.attention.ops.paged_attn import PagedAttention
...
@@ -15,8 +16,10 @@ from vllm.attention.ops.paged_attn import PagedAttention
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.attention.ops.triton_unified_attention
import
unified_attention
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
(
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
FlashAttentionMetadata
,
FlashAttentionMetadataBuilder
)
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_local_attention_virtual_batches
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
...
@@ -26,12 +29,161 @@ if TYPE_CHECKING:
...
@@ -26,12 +29,161 @@ if TYPE_CHECKING:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
TritonAttentionMetadataBuilder
(
FlashAttentionMetadataBuilder
):
@
dataclass
class
TritonAttentionMetadata
:
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens
:
int
# Number of tokens excluding padding.
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
seq_lens
:
torch
.
Tensor
block_table
:
torch
.
Tensor
slot_mapping
:
torch
.
Tensor
# For cascade attention.
use_cascade
:
bool
common_prefix_len
:
int
cu_prefix_query_lens
:
Optional
[
torch
.
Tensor
]
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
# Optional aot scheduling
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
full_cudagraph_supported
:
ClassVar
[
bool
]
=
True
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
def
__init__
(
self
,
runner
:
"GPUModelRunner"
,
kv_cache_spec
:
AttentionSpec
,
block_table
:
BlockTable
):
block_table
:
BlockTable
):
super
().
__init__
(
runner
,
kv_cache_spec
,
block_table
)
self
.
runner
=
runner
self
.
aot_schedule
=
False
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
kv_cache_spec
=
kv_cache_spec
self
.
block_table
=
block_table
def
build_for_cudagraph_capture
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
TritonAttentionMetadata
:
attn_metadata
=
self
.
build
(
0
,
common_attn_metadata
)
# When doing full graph capture, setting seq_lens to
# max_model_len will cause graph capture to be extremely
# slow, so here we set it to 1.
attn_metadata
.
seq_lens
.
fill_
(
1
)
return
attn_metadata
def
build
(
self
,
common_prefix_len
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
TritonAttentionMetadata
:
num_reqs
=
common_attn_metadata
.
num_reqs
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
self
.
runner
.
seq_lens_np
[:
num_reqs
].
max
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
seq_lens
=
common_attn_metadata
.
seq_lens
block_table
=
self
.
block_table
block_table_tensor
=
block_table
.
get_device_tensor
()[:
num_reqs
]
block_table
.
slot_mapping
[:
num_actual_tokens
].
copy_
(
block_table
.
slot_mapping_cpu
[:
num_actual_tokens
],
non_blocking
=
True
)
# Fill unused with -1. Needed for reshape_and_cache in full cuda graph
# mode.
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
# for local attention
local_attn_metadata
=
None
if
self
.
runner
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table_tensor
=
make_local_attention_virtual_batches
(
self
.
runner
.
attention_chunk_size
,
self
.
runner
.
query_start_loc_np
[:
num_reqs
+
1
],
self
.
runner
.
seq_lens_np
[:
num_reqs
],
block_table_tensor
,
self
.
block_size
,
)
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
runner
.
device
,
non_blocking
=
True
)
local_max_query_len
=
seqlens_q_local_np
.
max
()
local_max_seq_len
=
virt_k_seqlens_np
.
max
()
local_attn_metadata
=
TritonAttentionMetadata
\
.
LocalAttentionMetadata
(
local_query_start_loc
=
local_query_start_loc
,
local_seqused_k
=
local_seqused_k
,
local_block_table
=
virt_block_table_tensor
,
local_max_query_len
=
local_max_query_len
,
local_max_seq_len
=
local_max_seq_len
,
local_scheduler_metadata
=
None
,
)
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
cu_prefix_query_lens
=
torch
.
tensor
([
0
,
num_actual_tokens
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
prefix_kv_lens
=
torch
.
tensor
([
common_prefix_len
],
dtype
=
torch
.
int32
,
device
=
self
.
runner
.
device
)
suffix_kv_lens
=
(
self
.
runner
.
seq_lens_np
[:
num_reqs
]
-
common_prefix_len
)
suffix_kv_lens
=
torch
.
from_numpy
(
suffix_kv_lens
).
to
(
self
.
runner
.
device
)
else
:
cu_prefix_query_lens
=
None
prefix_kv_lens
=
None
suffix_kv_lens
=
None
prefix_scheduler_metadata
=
None
attn_metadata
=
TritonAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
seq_lens
=
seq_lens
,
block_table
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
use_cascade
=
use_cascade
,
common_prefix_len
=
common_prefix_len
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
)
return
attn_metadata
def
can_run_in_cudagraph
(
self
,
common_attn_metadata
:
CommonAttentionMetadata
)
->
bool
:
# Full CUDA Graph always supported
return
True
class
TritonAttentionBackend
(
AttentionBackend
):
class
TritonAttentionBackend
(
AttentionBackend
):
...
@@ -52,7 +204,7 @@ class TritonAttentionBackend(AttentionBackend):
...
@@ -52,7 +204,7 @@ class TritonAttentionBackend(AttentionBackend):
@
staticmethod
@
staticmethod
def
get_metadata_cls
()
->
type
[
"AttentionMetadata"
]:
def
get_metadata_cls
()
->
type
[
"AttentionMetadata"
]:
return
Flash
AttentionMetadata
return
Triton
AttentionMetadata
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
...
vllm/v1/attention/backends/utils.py
View file @
a44b1c95
...
@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
...
@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
vllm.utils
import
cdiv
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
from
vllm.v1.worker.gpu_input_batch
import
InputBatch
...
@@ -140,3 +142,169 @@ def get_kv_cache_layout():
...
@@ -140,3 +142,169 @@ def get_kv_cache_layout():
"detected. Setting KV cache layout to %s."
,
cache_layout
)
"detected. Setting KV cache layout to %s."
,
cache_layout
)
return
cache_layout
return
cache_layout
#
# Take in `query_start_loc_np` and `seq_lens_np` and break the sequences into
# local attention blocks, where each block is passed to the attention kernel
# as an independent local ("virtual") batch item.
#
# For example, if we are performing a chunked prefill on a batch of 3 sequences:
# q_seqlens = [4, 10, 5]
# kv_seqlens = [6, 17, 9]
# Then normally for regular attention we would compute with an attention mask
# for batch idx 0 (q_seqlens = 4, kv_seqlens = 6) like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1 1 1 1 1
# 3 | 1 1 1 1 1 1
#
# for local attention (with attn_chunk_size = 4) we would compute with an
# attention mask like:
# batch idx: 0 (q_seqlens = 4, kv_seqlens = 6, attn_chunk_size = 4)
# k_toks > 0 1 2 3 4 5
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# 2 | 1
# 3 | 1 1
#
# We can simulate this mask using standard flash-attention by breaking the
# sequences into local ("virtual") batches, where each local batch item is a
# local attention block, so in this case batch idx 0 would be broken up into:
#
# local-batch idx: 0 (q_seqlens = 2, kv_seqlens = 4) (batch 0)
# k_toks > 0 1 2 3
# q_toks v _____________
# 0 | 1 1 1
# 1 | 1 1 1 1
# local-batch idx: 1 (q_seqlens = 2, kv_seqlens = 2) (batch 0)
# k_toks > 4 5
# q_toks v _____________
# 2 | 1
# 3 | 1 1
#
# e.g. if we have:
# attn_chunk_size = 4
# query_start_loc_np = [0, 4, 14, 19] (q_seqlens = [4, 10, 5])
# Then this function would return:
# __b0__ ______b1______ __b2__ < orig batch indices
# q_seqlens_local = [ 2, 2, 1, 4, 4, 1, 4, 1]
# cu_seqlens_q_local = [0, 2, 4, 5, 9, 13, 14, 18, 19]
# seqlens_k_local = [ 4, 2, 4, 4, 4, 1, 4, 1]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
Tensor
,
block_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
Tensor
]:
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
# Handle if we are starting in the middle of a local attention block,
# we assume q_seqlens > 0 (for all elements), for each batch idx we compute
# the number of tokens that are not in the first local attention block and
# then we can simply use a cdiv for the rest.
# For example if we have:
# attn_chunk_size = 4
# q_seqlens = [4, 10, 5]
# k_seqlens = [6, 17, 9]
# Then we would get:
# new_tokens_in_first_block = [2, 1, 4]
# local_blocks = [2, 4, 2]
q_tokens_in_first_block
=
np
.
minimum
(
attn_chunk_size
-
((
seq_lens_np
-
q_seqlens
)
%
attn_chunk_size
),
q_seqlens
).
astype
(
np
.
int32
)
tokens_in_last_block
=
attn_chunk_size
+
(
seq_lens_np
%
-
attn_chunk_size
)
local_blocks
=
1
+
cdiv
(
q_seqlens
-
q_tokens_in_first_block
,
attn_chunk_size
)
# Once we know the number of local blocks we can compute the request spans
# for each batch idx, we can figure out the number of "virtual" requests we
# have to make,
# For the above example we would get:
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
#
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks
=
np
.
cumsum
(
local_blocks
)
virtual_batches
=
cu_num_blocks
[
-
1
]
# arange step 2. [2, 6, 8] -> [0, 0, 2, 2, 2, 2, 6, 6]
block_offsets
=
np
.
repeat
(
cu_num_blocks
-
local_blocks
,
local_blocks
)
# arange step 3. [0, 1, 0, 1, 2, 3, 0, 1]
arange
=
np
.
arange
(
virtual_batches
,
dtype
=
np
.
int32
)
-
block_offsets
# also compute reverse arange (i.e. [1, 0, 3, 2, 1, 0, 1, 0])
rarange
=
np
.
repeat
(
local_blocks
,
local_blocks
)
-
arange
-
1
# Then we can compute the seqlens_q_local, handling the fact that the
# first and last blocks could be partial
seqlens_q_local
=
\
np
.
repeat
(
q_seqlens
-
q_tokens_in_first_block
,
local_blocks
)
# set the first block since this may be a partial block
seqlens_q_local
[
arange
==
0
]
=
q_tokens_in_first_block
# set the remaining blocks
seqlens_q_local
[
arange
>
0
]
=
np
.
minimum
(
seqlens_q_local
-
attn_chunk_size
*
(
arange
-
1
),
attn_chunk_size
)[
arange
>
0
]
# convert from q_seqlens to cu_seqlens_q
cu_seqlens_q_local
=
np
.
pad
(
np
.
cumsum
(
seqlens_q_local
),
(
1
,
0
))
\
.
astype
(
np
.
int32
)
# compute the seqlens_k_local,
# basically a full local attention block for all but the last block in each
# batch
# For our example this will be:
# seqlens_k_local = [4, 2, 4, 4, 4, 1, 4, 1]
seqlens_k_local
=
np
.
full
(
cu_num_blocks
[
-
1
],
attn_chunk_size
,
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
np
.
repeat
(
tokens_in_last_block
,
local_blocks
))
# For the example the local attention blocks start at:
# _b0_ _____b1_____ _b2_
# k_seqstarts_absolute = [0, 4, 4, 8, 12, 16, 4, 8]
block_starts
=
k_seqstarts_absolute
//
block_size
assert
attn_chunk_size
%
block_size
==
0
,
\
f
"attn_chunk_size
{
attn_chunk_size
}
is not "
\
f
"divisible by block_size
{
block_size
}
"
pages_per_local_batch
=
attn_chunk_size
//
block_size
# Create a block_table for the local attention blocks
# For our example if we have a block-table like (assuming block_size=2):
# block_table = [
# [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9], < batch 0
# [10, 11, 12, 13, 14, 15, 16, 17, 18, 19], < batch 1
# [20, 21, 22, 23, 24, 25, 26, 27, 28, 29], < batch 2
# ]
# Then for the local batches we would want a block-table like
# block_table_local = [
# [ 0, 1 ], < local-batch 0, (batch 0, starting from k[0])
# [ 2, 3 ], < local-batch 1, (batch 0, starting from k[4])
# [ 12, 13 ], < local-batch 2, (batch 1, starting from k[4])
# [ 14, 15 ], < local-batch 3, (batch 1, starting from k[8])
# [ 16, 17 ], < local-batch 4, (batch 1, starting from k[12])
# [ 18, 19 ], < local-batch 5, (batch 1, starting from k[16])
# [ 22, 23 ], < local-batch 6, (batch 2, starting from k[4])
# [ 24, 25 ], < local-batch 7, (batch 2, starting from k[8])
# ]
block_indices
=
np
.
broadcast_to
(
np
.
arange
(
pages_per_local_batch
,
dtype
=
np
.
int32
),
(
virtual_batches
,
pages_per_local_batch
))
\
+
np
.
expand_dims
(
block_starts
,
axis
=
1
)
block_indices
=
block_indices
.
flatten
().
clip
(
max
=
block_table
.
shape
[
1
]
-
1
)
batch_indices
=
np
.
repeat
(
np
.
arange
(
actual_batch_size
,
dtype
=
np
.
int32
),
local_blocks
*
pages_per_local_batch
)
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
block_table_local
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment