Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
89cab4d0
Unverified
Commit
89cab4d0
authored
Jul 18, 2025
by
Lucas Wilkinson
Committed by
GitHub
Jul 18, 2025
Browse files
[Attention] Make local attention backend agnostic (#21093)
parent
b9a21e91
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
94 additions
and
242 deletions
+94
-242
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+10
-74
vllm/v1/attention/backends/flashinfer.py
vllm/v1/attention/backends/flashinfer.py
+1
-4
vllm/v1/attention/backends/rocm_aiter_fa.py
vllm/v1/attention/backends/rocm_aiter_fa.py
+7
-90
vllm/v1/attention/backends/triton_attn.py
vllm/v1/attention/backends/triton_attn.py
+7
-61
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+24
-6
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+7
-3
vllm/v1/kv_cache_interface.py
vllm/v1/kv_cache_interface.py
+15
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+23
-4
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
89cab4d0
...
@@ -25,9 +25,9 @@ if is_flash_attn_varlen_func_available():
...
@@ -25,9 +25,9 @@ if is_flash_attn_varlen_func_available():
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.config
import
VllmConfig
,
get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
CommonAttentionMetadata
,
make_local_attention_virtual_batches
)
get_kv_cache_layout
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
...
@@ -130,18 +130,6 @@ class FlashAttentionMetadata:
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
max_num_splits
:
int
=
0
max_num_splits
:
int
=
0
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
def
_get_sliding_window_configs
(
def
_get_sliding_window_configs
(
vllm_config
:
VllmConfig
)
->
set
[
Optional
[
tuple
[
int
,
int
]]]:
vllm_config
:
VllmConfig
)
->
set
[
Optional
[
tuple
[
int
,
int
]]]:
...
@@ -221,7 +209,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -221,7 +209,6 @@ class FlashAttentionMetadataBuilder(
max_query_len
=
common_attn_metadata
.
max_query_len
max_query_len
=
common_attn_metadata
.
max_query_len
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
max
())
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
max
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
...
@@ -266,40 +253,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -266,40 +253,6 @@ class FlashAttentionMetadataBuilder(
)
)
return
None
return
None
# for local attention
local_attn_metadata
=
None
if
self
.
model_config
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table_tensor
=
make_local_attention_virtual_batches
(
self
.
model_config
.
attention_chunk_size
,
query_start_loc_cpu
.
numpy
(),
seq_lens_cpu
.
numpy
(),
block_table_tensor
,
self
.
block_size
,
)
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_max_query_len
=
seqlens_q_local_np
.
max
()
local_max_seq_len
=
virt_k_seqlens_np
.
max
()
local_scheduler_metadata
=
schedule
(
batch_size
=
local_query_start_loc
.
shape
[
0
]
-
1
,
cu_query_lens
=
local_query_start_loc
,
max_query_len
=
local_max_query_len
,
seqlens
=
local_seqused_k
,
max_seq_len
=
local_max_seq_len
,
causal
=
True
)
local_attn_metadata
=
FlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
local_query_start_loc
,
local_seqused_k
=
local_seqused_k
,
local_block_table
=
virt_block_table_tensor
,
local_max_query_len
=
local_max_query_len
,
local_max_seq_len
=
local_max_seq_len
,
local_scheduler_metadata
=
local_scheduler_metadata
,
)
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
if
use_cascade
:
...
@@ -371,7 +324,6 @@ class FlashAttentionMetadataBuilder(
...
@@ -371,7 +324,6 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
max_num_splits
=
max_num_splits
,
max_num_splits
=
max_num_splits
,
)
)
...
@@ -517,27 +469,13 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -517,27 +469,13 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_q_scale
)
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
use_local_attn
=
\
cu_seqlens_q
=
attn_metadata
.
query_start_loc
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
max_seqlen_k
=
attn_metadata
.
max_seq_len
if
use_local_attn
:
block_table
=
attn_metadata
.
block_table
assert
attn_metadata
.
local_attn_metadata
is
not
None
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
scheduler_metadata
=
local_metadata
.
local_scheduler_metadata
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
scheduler_metadata
=
attn_metadata
.
scheduler_metadata
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
descale_shape
=
(
cu_seqlens_q
.
shape
[
0
]
-
1
,
key
.
shape
[
1
])
...
@@ -565,8 +503,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -565,8 +503,6 @@ class FlashAttentionImpl(AttentionImpl):
)
)
return
output
return
output
assert
not
use_local_attn
,
(
"Cascade attention does not support local attention."
)
# Cascade attention (rare case).
# Cascade attention (rare case).
cascade_attention
(
cascade_attention
(
output
[:
num_actual_tokens
],
output
[:
num_actual_tokens
],
...
...
vllm/v1/attention/backends/flashinfer.py
View file @
89cab4d0
...
@@ -496,10 +496,6 @@ class FlashInferImpl(AttentionImpl):
...
@@ -496,10 +496,6 @@ class FlashInferImpl(AttentionImpl):
kv_sharing_target_layer_name
:
Optional
[
int
]
=
None
,
kv_sharing_target_layer_name
:
Optional
[
int
]
=
None
,
use_irope
:
bool
=
False
,
use_irope
:
bool
=
False
,
)
->
None
:
)
->
None
:
if
use_irope
:
logger
.
warning_once
(
"Using irope in FlashInfer is not supported yet, it will fall"
" back to global attention for long context."
)
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_size
=
head_size
self
.
head_size
=
head_size
self
.
scale
=
float
(
scale
)
self
.
scale
=
float
(
scale
)
...
@@ -514,6 +510,7 @@ class FlashInferImpl(AttentionImpl):
...
@@ -514,6 +510,7 @@ class FlashInferImpl(AttentionImpl):
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
logits_soft_cap
=
logits_soft_cap
self
.
logits_soft_cap
=
logits_soft_cap
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
self
.
kv_sharing_target_layer_name
=
kv_sharing_target_layer_name
self
.
use_irope
=
use_irope
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
self
.
num_queries_per_kv
=
self
.
num_heads
//
self
.
num_kv_heads
...
...
vllm/v1/attention/backends/rocm_aiter_fa.py
View file @
89cab4d0
...
@@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
...
@@ -13,8 +13,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
(
make_local_attention_virtual_batches
)
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.attention.backends.utils
import
CommonAttentionMetadata
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
@@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
...
@@ -201,9 +199,7 @@ class AiterFlashAttentionMetadataBuilder:
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
max
())
max_seq_len
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
max
())
total_tokens
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
sum
())
total_tokens
=
int
(
common_attn_metadata
.
seq_lens_cpu
.
sum
())
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc
=
common_attn_metadata
.
query_start_loc
query_start_loc_cpu
=
common_attn_metadata
.
query_start_loc_cpu
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens
=
common_attn_metadata
.
seq_lens
seq_lens_cpu
=
common_attn_metadata
.
seq_lens_cpu
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
slot_mapping
=
common_attn_metadata
.
slot_mapping
...
@@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
...
@@ -215,56 +211,6 @@ class AiterFlashAttentionMetadataBuilder:
dtype
=
cu_seq_lens
.
dtype
,
dtype
=
cu_seq_lens
.
dtype
,
out
=
cu_seq_lens
[
1
:])
out
=
cu_seq_lens
[
1
:])
def
schedule
(
batch_size
,
cu_query_lens
,
max_query_len
,
seqlens
,
max_seq_len
,
causal
):
return
None
# for local attention
local_attn_metadata
=
None
if
self
.
model_config
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table_tensor
=
make_local_attention_virtual_batches
(
self
.
model_config
.
attention_chunk_size
,
query_start_loc_cpu
.
numpy
(),
seq_lens_cpu
.
numpy
(),
block_table_tensor
,
self
.
block_size
,
)
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_max_query_len
=
seqlens_q_local_np
.
max
().
item
()
local_max_seq_len
=
virt_k_seqlens_np
.
max
().
item
()
local_scheduler_metadata
=
schedule
(
batch_size
=
local_query_start_loc
.
shape
[
0
]
-
1
,
cu_query_lens
=
local_query_start_loc
,
max_query_len
=
local_max_query_len
,
seqlens
=
local_seqused_k
,
max_seq_len
=
local_max_seq_len
,
causal
=
True
)
local_cu_seq_lens
=
torch
.
zeros
(
virt_k_seqlens_np
.
shape
[
0
]
+
1
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
local_cu_seq_lens
[
1
:]
=
torch
.
cumsum
(
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
device
=
self
.
device
,
dtype
=
torch
.
int32
,
non_blocking
=
True
),
dim
=
0
)
local_attn_metadata
=
\
AiterFlashAttentionMetadata
.
LocalAttentionMetadata
(
local_query_start_loc
=
local_query_start_loc
,
local_seqused_k
=
local_seqused_k
,
local_block_table
=
virt_block_table_tensor
,
local_max_query_len
=
local_max_query_len
,
local_max_seq_len
=
local_max_seq_len
,
local_cu_seq_lens
=
local_cu_seq_lens
,
local_scheduler_metadata
=
local_scheduler_metadata
,
)
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
cu_prefix_query_lens
=
None
cu_prefix_query_lens
=
None
...
@@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
...
@@ -286,7 +232,6 @@ class AiterFlashAttentionMetadataBuilder:
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
)
)
return
attn_metadata
return
attn_metadata
...
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
...
@@ -377,19 +322,6 @@ class AiterFlashAttentionMetadata:
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
prefix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
suffix_kv_lens
:
Optional
[
torch
.
Tensor
]
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_cu_seq_lens
:
torch
.
Tensor
local_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
class
AiterFlashAttentionImpl
(
AttentionImpl
):
class
AiterFlashAttentionImpl
(
AttentionImpl
):
...
@@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -521,25 +453,12 @@ class AiterFlashAttentionImpl(AttentionImpl):
layer
.
_q_scale
)
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Compute attention and update output up to `num_actual_tokens`.
if
not
attn_metadata
.
use_cascade
:
use_local_attn
=
\
cu_seqlens_q
=
attn_metadata
.
query_start_loc
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
if
not
attn_metadata
.
use_cascade
or
use_local_attn
:
max_seqlen_k
=
attn_metadata
.
max_seq_len
if
use_local_attn
:
block_table
=
attn_metadata
.
block_table
assert
attn_metadata
.
local_attn_metadata
is
not
None
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
if
max_seqlen_q
>
1
:
if
max_seqlen_q
>
1
:
cu_seq_lens
=
attn_metadata
.
cu_seq_lens
cu_seq_lens
=
attn_metadata
.
cu_seq_lens
...
@@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
...
@@ -557,9 +476,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
alibi_slopes
=
self
.
alibi_slopes
,
alibi_slopes
=
self
.
alibi_slopes
,
window_size
=
self
.
sliding_window
,
window_size
=
self
.
sliding_window
,
block_table
=
block_table
,
block_table
=
block_table
,
cu_seqlens_k
=
(
cu_seq_lens
if
not
use_local_attn
else
cu_seqlens_k
=
cu_seq_lens
)
local_metadata
.
local_cu_seq_lens
),
)
_
,
num_heads
,
head_size
=
query
.
shape
_
,
num_heads
,
head_size
=
query
.
shape
_PARTITION_SIZE_ROCM
=
256
_PARTITION_SIZE_ROCM
=
256
...
...
vllm/v1/attention/backends/triton_attn.py
View file @
89cab4d0
...
@@ -18,9 +18,8 @@ from vllm.config import VllmConfig
...
@@ -18,9 +18,8 @@ from vllm.config import VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionMetadata
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
)
make_local_attention_virtual_batches
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
...
@@ -55,18 +54,6 @@ class TritonAttentionMetadata:
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
prefix_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
=
None
# for local attention
@
dataclass
class
LocalAttentionMetadata
:
local_query_start_loc
:
torch
.
Tensor
local_seqused_k
:
torch
.
Tensor
local_block_table
:
torch
.
Tensor
local_max_query_len
:
int
local_max_seq_len
:
int
local_scheduler_metadata
:
Optional
[
torch
.
Tensor
]
local_attn_metadata
:
Optional
[
LocalAttentionMetadata
]
=
None
class
TritonAttentionMetadataBuilder
(
class
TritonAttentionMetadataBuilder
(
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
AttentionMetadataBuilder
[
TritonAttentionMetadata
]):
...
@@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
...
@@ -111,34 +98,6 @@ class TritonAttentionMetadataBuilder(
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
block_table_tensor
=
common_attn_metadata
.
block_table_tensor
slot_mapping
=
common_attn_metadata
.
slot_mapping
slot_mapping
=
common_attn_metadata
.
slot_mapping
# for local attention
local_attn_metadata
=
None
if
self
.
attention_chunk_size
is
not
None
:
seqlens_q_local_np
,
virt_q_cu_seqlens_np
,
virt_k_seqlens_np
,
\
virt_block_table_tensor
=
make_local_attention_virtual_batches
(
self
.
attention_chunk_size
,
common_attn_metadata
.
query_start_loc_cpu
.
numpy
(),
common_attn_metadata
.
seq_lens_cpu
.
numpy
(),
block_table_tensor
,
self
.
block_size
,
)
local_query_start_loc
=
torch
.
from_numpy
(
virt_q_cu_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_seqused_k
=
torch
.
from_numpy
(
virt_k_seqlens_np
).
to
(
self
.
device
,
non_blocking
=
True
)
local_max_query_len
=
seqlens_q_local_np
.
max
().
item
()
local_max_seq_len
=
virt_k_seqlens_np
.
max
().
item
()
local_attn_metadata
=
TritonAttentionMetadata
\
.
LocalAttentionMetadata
(
local_query_start_loc
=
local_query_start_loc
,
local_seqused_k
=
local_seqused_k
,
local_block_table
=
virt_block_table_tensor
,
local_max_query_len
=
local_max_query_len
,
local_max_seq_len
=
local_max_seq_len
,
local_scheduler_metadata
=
None
,
)
use_cascade
=
common_prefix_len
>
0
use_cascade
=
common_prefix_len
>
0
if
use_cascade
:
if
use_cascade
:
...
@@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
...
@@ -170,7 +129,6 @@ class TritonAttentionMetadataBuilder(
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
)
)
return
attn_metadata
return
attn_metadata
...
@@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
...
@@ -384,23 +342,11 @@ class TritonAttentionImpl(AttentionImpl):
layer
.
_q_scale
)
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
use_local_attn
=
\
cu_seqlens_q
=
attn_metadata
.
query_start_loc
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
if
use_local_attn
:
max_seqlen_k
=
attn_metadata
.
max_seq_len
assert
attn_metadata
.
local_attn_metadata
is
not
None
block_table
=
attn_metadata
.
block_table
local_metadata
=
attn_metadata
.
local_attn_metadata
cu_seqlens_q
=
local_metadata
.
local_query_start_loc
seqused_k
=
local_metadata
.
local_seqused_k
max_seqlen_q
=
local_metadata
.
local_max_query_len
max_seqlen_k
=
local_metadata
.
local_max_seq_len
block_table
=
local_metadata
.
local_block_table
else
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
max_seqlen_q
=
attn_metadata
.
max_query_len
max_seqlen_k
=
attn_metadata
.
max_seq_len
block_table
=
attn_metadata
.
block_table
if
use_prefill_decode_attn
:
if
use_prefill_decode_attn
:
# Compute attention and update output up to `num_actual_tokens`.
# Compute attention and update output up to `num_actual_tokens`.
...
...
vllm/v1/attention/backends/utils.py
View file @
89cab4d0
...
@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
...
@@ -272,11 +272,14 @@ def infer_global_hyperparameters(
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
# block_table_local : shape[local_virtual_batches, pages_per_local_batch]
def
make_local_attention_virtual_batches
(
def
make_local_attention_virtual_batches
(
attn_chunk_size
:
int
,
attn_chunk_size
:
int
,
query_start_loc_np
:
np
.
ndarray
,
common_attn_metadata
:
CommonAttentionMetadata
,
seq_lens_np
:
np
.
ndarray
,
block_table
:
torch
.
Tensor
,
block_size
:
int
=
0
,
block_size
:
int
=
0
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
torch
.
Tensor
]:
)
->
CommonAttentionMetadata
:
query_start_loc_np
=
common_attn_metadata
.
query_start_loc_cpu
.
numpy
()
seq_lens_np
=
common_attn_metadata
.
seq_lens_cpu
.
numpy
()
block_table
=
common_attn_metadata
.
block_table_tensor
device
=
common_attn_metadata
.
query_start_loc
.
device
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
q_seqlens
=
query_start_loc_np
[
1
:]
-
query_start_loc_np
[:
-
1
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
actual_batch_size
=
seq_lens_np
.
shape
[
0
]
...
@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
...
@@ -339,6 +342,7 @@ def make_local_attention_virtual_batches(
attn_chunk_size
,
attn_chunk_size
,
dtype
=
np
.
int32
)
dtype
=
np
.
int32
)
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
seqlens_k_local
[
cu_num_blocks
-
1
]
=
tokens_in_last_block
num_computed_tokens_local
=
seqlens_k_local
-
seqlens_q_local
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
k_seqstarts_absolute
=
np
.
repeat
(
seq_lens_np
,
local_blocks
)
-
\
(
rarange
*
attn_chunk_size
+
\
(
rarange
*
attn_chunk_size
+
\
...
@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
...
@@ -380,8 +384,22 @@ def make_local_attention_virtual_batches(
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
block_table_local
=
block_table
[
batch_indices
,
block_indices
]
\
.
view
(
virtual_batches
,
-
1
)
.
view
(
virtual_batches
,
-
1
)
return
seqlens_q_local
,
cu_seqlens_q_local
,
seqlens_k_local
,
\
query_start_loc_cpu
=
torch
.
from_numpy
(
cu_seqlens_q_local
)
block_table_local
seq_lens_cpu
=
torch
.
from_numpy
(
seqlens_k_local
)
return
CommonAttentionMetadata
(
query_start_loc_cpu
=
query_start_loc_cpu
,
query_start_loc
=
query_start_loc_cpu
.
to
(
device
=
device
,
non_blocking
=
True
),
seq_lens_cpu
=
seq_lens_cpu
,
seq_lens
=
seq_lens_cpu
.
to
(
device
=
device
,
non_blocking
=
True
),
num_computed_tokens_cpu
=
torch
.
from_numpy
(
num_computed_tokens_local
),
num_reqs
=
len
(
seq_lens_cpu
),
num_actual_tokens
=
common_attn_metadata
.
num_actual_tokens
,
max_query_len
=
seqlens_q_local
.
max
(),
block_table_tensor
=
block_table_local
,
slot_mapping
=
common_attn_metadata
.
slot_mapping
,
)
def
split_decodes_and_prefills
(
def
split_decodes_and_prefills
(
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
89cab4d0
...
@@ -7,7 +7,8 @@ from typing import Callable
...
@@ -7,7 +7,8 @@ from typing import Callable
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheSpec
,
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
MambaSpec
,
SlidingWindowSpec
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -256,8 +257,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
...
@@ -256,8 +257,10 @@ class FullAttentionManager(SingleTypeKVCacheManager):
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
use_eagle
:
bool
,
use_eagle
:
bool
,
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
assert
isinstance
(
kv_cache_spec
,
FullAttentionSpec
),
(
assert
isinstance
(
"FullAttentionManager can only be used for full attention groups"
)
kv_cache_spec
,
(
FullAttentionSpec
,
ChunkedLocalAttentionSpec
)
),
"FullAttentionManager can only be used for full attention "
\
"and chunked local attention groups"
computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...]
=
tuple
(
computed_blocks
:
tuple
[
list
[
KVCacheBlock
],
...]
=
tuple
(
[]
for
_
in
range
(
len
(
kv_cache_group_ids
)))
[]
for
_
in
range
(
len
(
kv_cache_group_ids
)))
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
max_num_blocks
=
max_length
//
kv_cache_spec
.
block_size
...
@@ -432,6 +435,7 @@ class MambaManager(SingleTypeKVCacheManager):
...
@@ -432,6 +435,7 @@ class MambaManager(SingleTypeKVCacheManager):
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
spec_manager_map
:
dict
[
type
[
KVCacheSpec
],
type
[
SingleTypeKVCacheManager
]]
=
{
FullAttentionSpec
:
FullAttentionManager
,
FullAttentionSpec
:
FullAttentionManager
,
ChunkedLocalAttentionSpec
:
FullAttentionManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
SlidingWindowSpec
:
SlidingWindowManager
,
MambaSpec
:
MambaManager
,
MambaSpec
:
MambaManager
,
}
}
...
...
vllm/v1/kv_cache_interface.py
View file @
89cab4d0
...
@@ -125,6 +125,21 @@ class FullAttentionSpec(AttentionSpec):
...
@@ -125,6 +125,21 @@ class FullAttentionSpec(AttentionSpec):
return
merged_spec
return
merged_spec
@
dataclass
class
ChunkedLocalAttentionSpec
(
AttentionSpec
):
attention_chunk_size
:
int
def
max_memory_usage_bytes
(
self
,
vllm_config
:
VllmConfig
)
->
int
:
max_model_len
=
vllm_config
.
model_config
.
max_model_len
return
cdiv
(
max_model_len
,
self
.
block_size
)
*
self
.
page_size_bytes
@
property
def
type_id
(
self
)
->
str
:
return
(
f
"local_attention_
{
self
.
attention_chunk_size
}
_
{
self
.
block_size
}
_
{
self
.
page_size_bytes
}
"
)
# noqa
@
dataclass
@
dataclass
class
SlidingWindowSpec
(
AttentionSpec
):
class
SlidingWindowSpec
(
AttentionSpec
):
sliding_window
:
int
sliding_window
:
int
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
89cab4d0
...
@@ -44,11 +44,14 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
...
@@ -44,11 +44,14 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
GiB_bytes
,
LazyLoader
,
check_use_alibi
,
get_dtype_size
,
GiB_bytes
,
LazyLoader
,
check_use_alibi
,
get_dtype_size
,
is_pin_memory_available
,
round_up
)
is_pin_memory_available
,
round_up
)
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.mamba_attn
import
Mamba2AttentionBackend
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
from
vllm.v1.attention.backends.utils
import
(
CommonAttentionMetadata
)
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
make_local_attention_virtual_batches
)
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_encoder_budget
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
FullAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
AttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
MambaSpec
,
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheSpec
,
MambaSpec
,
SlidingWindowSpec
)
SlidingWindowSpec
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
LogprobsTensors
,
ModelRunnerOutput
)
ModelRunnerOutput
)
...
@@ -705,6 +708,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -705,6 +708,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_decode_common_attn_metadata
is
None
:
spec_decode_common_attn_metadata
is
None
:
spec_decode_common_attn_metadata
=
common_attn_metadata
spec_decode_common_attn_metadata
=
common_attn_metadata
if
isinstance
(
kv_cache_group_spec
.
kv_cache_spec
,
ChunkedLocalAttentionSpec
):
common_attn_metadata
=
make_local_attention_virtual_batches
(
kv_cache_group_spec
.
kv_cache_spec
.
attention_chunk_size
,
common_attn_metadata
,
self
.
cache_config
.
block_size
)
# Prepare for cascade attention if enabled & beneficial.
# Prepare for cascade attention if enabled & beneficial.
common_prefix_len
=
0
common_prefix_len
=
0
builder
=
self
.
attn_metadata_builders
[
kv_cache_group_id
]
builder
=
self
.
attn_metadata_builders
[
kv_cache_group_id
]
...
@@ -2589,6 +2598,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2589,6 +2598,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO: Support other attention modules, e.g., cross-attention
# TODO: Support other attention modules, e.g., cross-attention
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
if
attn_module
.
attn_type
==
AttentionType
.
DECODER
:
use_local_attention
=
(
self
.
attention_chunk_size
is
not
None
and
attn_module
.
impl
.
use_irope
)
if
attn_module
.
sliding_window
is
not
None
:
if
attn_module
.
sliding_window
is
not
None
:
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
kv_cache_spec
[
layer_name
]
=
SlidingWindowSpec
(
block_size
=
block_size
,
block_size
=
block_size
,
...
@@ -2597,6 +2608,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -2597,6 +2608,14 @@ class GPUModelRunner(LoRAModelRunnerMixin):
dtype
=
self
.
kv_cache_dtype
,
dtype
=
self
.
kv_cache_dtype
,
sliding_window
=
attn_module
.
sliding_window
,
sliding_window
=
attn_module
.
sliding_window
,
use_mla
=
use_mla
)
use_mla
=
use_mla
)
elif
use_local_attention
:
kv_cache_spec
[
layer_name
]
=
(
ChunkedLocalAttentionSpec
(
block_size
=
block_size
,
num_kv_heads
=
attn_module
.
num_kv_heads
,
head_size
=
attn_module
.
head_size
,
dtype
=
self
.
kv_cache_dtype
,
attention_chunk_size
=
self
.
attention_chunk_size
,
use_mla
=
use_mla
))
else
:
else
:
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
kv_cache_spec
[
layer_name
]
=
FullAttentionSpec
(
block_size
=
block_size
,
block_size
=
block_size
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment