Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6f9d81d0
Unverified
Commit
6f9d81d0
authored
Nov 29, 2025
by
Isotr0py
Committed by
GitHub
Nov 28, 2025
Browse files
[V0 deprecation] Clean up legacy paged attention helper functions (#28043)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
fae69430
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
0 additions
and
334 deletions
+0
-334
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+0
-211
vllm/attention/ops/rocm_aiter_paged_attn.py
vllm/attention/ops/rocm_aiter_paged_attn.py
+0
-123
No files found.
vllm/attention/ops/paged_attn.py
View file @
6f9d81d0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
torch
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
if
current_platform
.
is_cuda_alike
():
from
vllm
import
_custom_ops
as
ops
elif
current_platform
.
is_xpu
():
from
vllm._ipex_ops
import
ipex_ops
as
ops
if
HAS_TRITON
:
from
vllm.attention.ops.prefix_prefill
import
context_attention_fwd
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE
=
512
@
dataclass
class
PagedAttentionMetadata
:
"""Metadata for PagedAttention."""
# (batch_size,). The length of sequences (entire tokens seen so far) per
# sequence.
seq_lens_tensor
:
torch
.
Tensor
|
None
# Maximum sequence length in the batch. 0 if it is prefill-only batch.
max_decode_seq_len
:
int
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables
:
torch
.
Tensor
|
None
class
PagedAttention
:
@
staticmethod
def
get_supported_head_sizes
()
->
list
[
int
]:
return
[
32
,
64
,
80
,
96
,
112
,
120
,
128
,
192
,
256
]
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
head_size
:
int
,
cache_dtype_str
:
str
=
"auto"
,
)
->
tuple
[
int
,
...]:
return
(
2
,
num_blocks
,
block_size
*
num_kv_heads
*
head_size
)
@
staticmethod
def
split_kv_cache
(
kv_cache
:
torch
.
Tensor
,
...
...
@@ -89,174 +49,3 @@ class PagedAttention:
k_scale
,
v_scale
,
)
@
staticmethod
def
forward_decode
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
max_seq_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
torch
.
Tensor
|
None
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
torch
.
Tensor
:
if
blocksparse_vert_stride
is
not
None
and
blocksparse_vert_stride
>
1
:
# use blocksparse paged attention
block_size
=
value_cache
.
size
(
-
1
)
assert
(
blocksparse_block_size
>
0
and
blocksparse_block_size
%
block_size
==
0
),
(
f
"
{
blocksparse_block_size
=
}
needs to be a multiple of"
f
"
{
block_size
=
}
used in block_tables."
)
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
num_seqs
,
num_heads
,
head_size
=
query
.
shape
max_num_partitions
=
(
max_seq_len
+
_PARTITION_SIZE
-
1
)
//
_PARTITION_SIZE
# NOTE(woosuk): We use a simple heuristic to decide whether to use
# PagedAttention V1 or V2. If the number of partitions is 1, we use
# V1 to avoid the overhead of reduction. Also, if the number of
# sequences or heads is large, we use V1 since there is enough work
# to parallelize.
# TODO(woosuk): Tune this heuristic.
# For context len > 8192, use V2 kernel to avoid shared memory shortage.
use_v1
=
max_seq_len
<=
8192
and
(
max_num_partitions
==
1
or
num_seqs
*
num_heads
>
512
)
if
use_v1
:
# Run PagedAttention V1.
ops
.
paged_attention_v1
(
output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
else
:
# Run PagedAttention V2.
assert
_PARTITION_SIZE
%
block_size
==
0
tmp_output
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
num_seqs
,
num_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ops
.
paged_attention_v2
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
,
block_tables
,
seq_lens
,
block_size
,
max_seq_len
,
alibi_slopes
,
kv_cache_dtype
,
k_scale
,
v_scale
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_block_size
,
blocksparse_head_sliding_step
,
)
return
output
@
staticmethod
def
forward_prefix
(
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
query_start_loc
:
torch
.
Tensor
,
seq_lens_tensor
:
torch
.
Tensor
,
max_query_len
:
int
,
alibi_slopes
:
torch
.
Tensor
|
None
,
sliding_window
:
int
|
None
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
max_seq_len
=
None
context_attention_fwd
(
query
,
key
,
value
,
output
,
kv_cache_dtype
,
key_cache
,
value_cache
,
block_tables
,
# query_start_loc is (batch_size + 1,)
query_start_loc
,
seq_lens_tensor
,
max_seq_len
,
max_query_len
,
k_scale
,
v_scale
,
alibi_slopes
,
sliding_window
,
)
return
output
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
src_key_cache
=
src_kv_cache
[
0
]
dst_key_cache
=
dst_kv_cache
[
0
]
ops
.
swap_blocks
(
src_key_cache
,
dst_key_cache
,
src_to_dst
)
src_value_cache
=
src_kv_cache
[
1
]
dst_value_cache
=
dst_kv_cache
[
1
]
ops
.
swap_blocks
(
src_value_cache
,
dst_value_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
list
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
key_caches
=
[
kv_cache
[
0
]
for
kv_cache
in
kv_caches
]
value_caches
=
[
kv_cache
[
1
]
for
kv_cache
in
kv_caches
]
ops
.
copy_blocks
(
key_caches
,
value_caches
,
src_to_dists
)
vllm/attention/ops/rocm_aiter_paged_attn.py
deleted
100644 → 0
View file @
fae69430
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
aiter
as
rocm_aiter
import
torch
from
vllm.attention.ops.paged_attn
import
PagedAttention
from
vllm.platforms
import
current_platform
from
vllm.utils.math_utils
import
cdiv
FP8_DTYPE
=
current_platform
.
fp8_dtype
()
class
AITERPagedAttention
(
PagedAttention
):
@
staticmethod
def
write_to_paged_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
)
->
None
:
if
kv_cache_dtype
not
in
[
"int8"
,
"fp8"
,
"fp8_e4m3"
]:
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
k_scale
,
v_scale
,
)
else
:
kv_cache_torch_dtype
=
FP8_DTYPE
if
"fp8"
in
kv_cache_dtype
else
torch
.
int8
key_cache
=
key_cache
.
view
(
kv_cache_torch_dtype
)
value_cache
=
value_cache
.
view
(
kv_cache_torch_dtype
)
rocm_aiter
.
reshape_and_cache_with_pertoken_quant
(
key
,
value
,
key_cache
,
value_cache
,
k_scale
,
v_scale
,
slot_mapping
.
flatten
(),
True
,
)
@
staticmethod
def
forward_decode
(
query
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
max_seq_len
:
int
,
kv_cache_dtype
:
str
,
num_kv_heads
:
int
,
scale
:
float
,
alibi_slopes
:
torch
.
Tensor
|
None
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_block_size
:
int
=
64
,
blocksparse_head_sliding_step
:
int
=
0
,
)
->
torch
.
Tensor
:
if
kv_cache_dtype
not
in
[
"int8"
,
"fp8"
,
"fp8_e4m3"
]:
return
PagedAttention
.
forward_decode
(
query
=
query
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
block_tables
=
block_tables
,
seq_lens
=
seq_lens
,
max_seq_len
=
max_seq_len
,
kv_cache_dtype
=
kv_cache_dtype
,
num_kv_heads
=
num_kv_heads
,
scale
=
scale
,
alibi_slopes
=
alibi_slopes
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
tp_rank
=
tp_rank
,
blocksparse_local_blocks
=
blocksparse_local_blocks
,
blocksparse_vert_stride
=
blocksparse_vert_stride
,
blocksparse_block_size
=
blocksparse_block_size
,
blocksparse_head_sliding_step
=
blocksparse_head_sliding_step
,
)
if
"fp8"
in
kv_cache_dtype
:
key_cache
=
key_cache
.
view
(
current_platform
.
fp8_dtype
())
value_cache
=
value_cache
.
view
(
current_platform
.
fp8_dtype
())
if
blocksparse_vert_stride
is
not
None
and
blocksparse_vert_stride
>
1
:
# use blocksparse paged attention
block_size
=
value_cache
.
size
(
-
1
)
assert
(
blocksparse_block_size
>
0
and
blocksparse_block_size
%
block_size
==
0
),
(
f
"
{
blocksparse_block_size
=
}
needs to be a multiple of"
f
"
{
block_size
=
}
used in block_tables."
)
output
=
torch
.
empty_like
(
query
)
block_size
=
value_cache
.
shape
[
3
]
max_num_blocks_per_seq
=
cdiv
(
max_seq_len
,
block_size
)
rocm_aiter
.
pa_fwd_asm
(
query
,
key_cache
,
value_cache
,
block_tables
,
seq_lens
,
max_num_blocks_per_seq
,
k_scale
,
v_scale
,
output
,
)
return
output
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment