Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
71470bc4
Unverified
Commit
71470bc4
authored
Jul 31, 2025
by
Yong Hoon Shin
Committed by
GitHub
Jul 31, 2025
Browse files
[Misc] Add unit tests for chunked local attention (#21692)
Signed-off-by:
Yong Hoon Shin
<
yhshin@meta.com
>
parent
9e0726e5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
219 additions
and
13 deletions
+219
-13
tests/v1/attention/test_chunked_local_attention.py
tests/v1/attention/test_chunked_local_attention.py
+196
-0
tests/v1/attention/utils.py
tests/v1/attention/utils.py
+23
-13
No files found.
tests/v1/attention/test_chunked_local_attention.py
0 → 100644
View file @
71470bc4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
dataclasses
import
dataclass
import
numpy
as
np
import
pytest
import
torch
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
vllm.v1.attention.backends.utils
import
(
make_local_attention_virtual_batches
)
@dataclass
class LocalAttentionTestData:
    """One parametrized case for ``make_local_attention_virtual_batches``.

    Groups the arguments handed to the function under test together with
    the values the test expects to read back from its result.
    """

    # ---- Inputs -----------------------------------------------------------
    # Query / sequence lengths of the original (un-chunked) batch.
    batch_spec: BatchSpec
    # Local-attention chunk size passed to the function under test.
    attn_chunk_size: int
    # Block size used when building the block table for the metadata.
    block_size: int

    # ---- Expected outputs (one entry per local "virtual" batch) -----------
    # Compared against the diff of the result's ``query_start_loc_cpu``.
    expected_q_seqlens: list[int]
    # Compared against the result's ``seq_lens_cpu``.
    expected_k_seqlens: list[int]
    # Compared against the result's ``block_table_tensor``.
    expected_local_block_table: list[list[int]]
test_data_list = [
    # Same as the example in the docstring of
    # make_local_attention_virtual_batches, except that the block table has
    # 9 columns instead of 10.
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[4, 10, 5],
            seq_lens=[6, 17, 9],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[2, 2, 1, 4, 4, 1, 4, 1],
        expected_k_seqlens=[4, 2, 4, 4, 4, 1, 4, 1],
        # 2 pages per local batch
        # (chunk size 4 // block size 2)
        expected_local_block_table=[
            [0, 1],  # local-batch 0, (batch 0, starting from k[0])
            [2, 3],  # local-batch 1, (batch 0, starting from k[4])
            [11, 12],  # local-batch 2, (batch 1, starting from k[4])
            [13, 14],  # local-batch 3, (batch 1, starting from k[8])
            [15, 16],  # local-batch 4, (batch 1, starting from k[12])
            [17, 17],  # local-batch 5, (batch 1, starting from k[16])
            [20, 21],  # local-batch 6, (batch 2, starting from k[4])
            [22, 23],  # local-batch 7, (batch 2, starting from k[8])
        ],
    ),
    # Case where block indices are not clipped to block table ncols-1
    # because tokens_in_last_block == attn_chunk_size
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[8],
            seq_lens=[12],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[4, 4],
        expected_k_seqlens=[4, 4],
        expected_local_block_table=[
            [2, 3],
            [4, 5],
        ],
    ),
    # Case where all kv_seq positions are involved in attn
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[7],
            # 10 - 7 = 3 previously computed tokens
            seq_lens=[10],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[1, 4, 2],
        expected_k_seqlens=[4, 4, 2],
        expected_local_block_table=[
            [0, 1],
            [2, 3],
            [4, 4],
        ],
    ),
    # Case where attn_chunk_size > kv_seq_len
    # so no extra mini virtual batches are created
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[4],
            seq_lens=[6],
        ),
        # Larger than kv_seq_len
        attn_chunk_size=10,
        block_size=2,
        # No change to q_seqlens and k_seqlens
        expected_q_seqlens=[4],
        expected_k_seqlens=[6],
        # In this case, we only need a block-table like:
        #   block_table = [ [0, 1, 2] ]  # 1 batch, 3 pages
        # But we need to pad it to 5 pages per local batch
        # because currently the pages_per_local_batch
        # is calculated as (attn_chunk_size // block_size)
        expected_local_block_table=[
            [0, 1, 2, 2, 2],
        ],
    ),
    # Block size equal to chunk size
    # Expect single page per batch in local batch table
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[6, 6],
            seq_lens=[8, 8],
        ),
        attn_chunk_size=4,
        block_size=4,
        expected_q_seqlens=[2, 4, 2, 4],
        expected_k_seqlens=[4, 4, 4, 4],
        # Initial block table = [
        #    [0, 1],  < batch 0
        #    [2, 3],  < batch 1
        # ]
        expected_local_block_table=[
            [0],  # local-batch 0, (batch 0, starting from k[0])
            [1],  # local-batch 1, (batch 0, starting from k[4])
            [2],  # local-batch 2, (batch 1, starting from k[0])
            [3],  # local-batch 3, (batch 1, starting from k[4])
        ],
    ),
    # Case where query falls in the second attention chunk
    #  k_toks >   0 1 2 3 4
    #  q_toks v  _____________
    #         0 | 1
    #         1 | 1 1
    #         2 | 1 1 1
    #         3 | 1 1 1 1
    #         4 |         1
    #  where tokens 0,1,2,3 have been pre-computed
    LocalAttentionTestData(
        batch_spec=BatchSpec(
            query_lens=[1],
            seq_lens=[5],
        ),
        attn_chunk_size=4,
        block_size=2,
        expected_q_seqlens=[1],
        expected_k_seqlens=[1],
        expected_local_block_table=[
            [2, 2],
        ],
    ),
]
@pytest.mark.parametrize("test_data", test_data_list)
def test_local_attention_virtual_batches(test_data: LocalAttentionTestData):
    """Check make_local_attention_virtual_batches against a hand-built case.

    Builds common attention metadata for ``test_data.batch_spec`` with a
    deterministic (arange) block table, runs
    ``make_local_attention_virtual_batches`` and compares the resulting
    per-virtual-batch query lengths, key lengths and block table with the
    expected values stored on ``test_data``.

    Args:
        test_data: Inputs and expected outputs for a single scenario.
    """
    # All tensors below are placed on "cuda:0". Skip cleanly on hosts
    # without CUDA instead of erroring inside tensor creation.
    if not torch.cuda.is_available():
        pytest.skip("requires a CUDA device (tensors are created on cuda:0)")
    device = torch.device("cuda:0")

    batch_spec = test_data.batch_spec
    attn_chunk_size = test_data.attn_chunk_size
    block_size = test_data.block_size

    # Create common attention metadata
    common_attn_metadata = create_common_attn_metadata(
        batch_spec,
        block_size,
        device,
        # Use torch.arange instead of torch.randint so we can assert on
        # block table tensor values. The block table will have shape
        # (num_batches, cdiv(max_seq_len, block_size)) and its values are
        # aranged row by row from 0 to
        # num_batches * cdiv(max_seq_len, block_size) - 1
        arange_block_indices=True,
    )

    # Call the function
    result = make_local_attention_virtual_batches(attn_chunk_size,
                                                  common_attn_metadata,
                                                  block_size)

    # Convert to numpy for easier comparison; query lengths are recovered
    # from consecutive differences of the cumulative start locations.
    actual_q_seqlens = np.diff(result.query_start_loc_cpu.numpy())
    actual_k_seqlens = result.seq_lens_cpu.numpy()

    # Check that all query lengths are less than or equal to attn_chunk_size
    assert all(q_len <= attn_chunk_size for q_len in actual_q_seqlens)
    # Check that all key lengths are less than or equal to attn_chunk_size
    assert all(k_len <= attn_chunk_size for k_len in actual_k_seqlens)
    # Check that the total number of query tokens is preserved
    assert sum(actual_q_seqlens) == sum(batch_spec.query_lens)

    # Verify results
    np.testing.assert_array_equal(actual_q_seqlens,
                                  test_data.expected_q_seqlens)
    np.testing.assert_array_equal(actual_k_seqlens,
                                  test_data.expected_k_seqlens)

    expected_block_table_tensor = torch.tensor(
        test_data.expected_local_block_table,
        dtype=torch.int32,
        device=device,
    )

    # Printed output is only surfaced by pytest when the test fails (or with
    # -s), which makes block-table mismatches easier to diagnose.
    print(f"Expected block table:\n{expected_block_table_tensor}")
    print(f"Actual block table:\n{result.block_table_tensor}")

    torch.testing.assert_close(result.block_table_tensor,
                               expected_block_table_tensor)
tests/v1/attention/utils.py
View file @
71470bc4
...
...
@@ -40,7 +40,8 @@ def create_common_attn_metadata(
batch_spec
:
BatchSpec
,
block_size
:
int
,
device
:
torch
.
device
,
max_block_idx
:
int
=
1000
)
->
CommonAttentionMetadata
:
max_block_idx
:
int
=
1000
,
arange_block_indices
:
bool
=
False
)
->
CommonAttentionMetadata
:
"""Create CommonAttentionMetadata from a BatchSpec and ModelParams."""
# Create query start locations
query_start_loc
=
torch
.
zeros
(
batch_spec
.
batch_size
+
1
,
...
...
@@ -65,19 +66,28 @@ def create_common_attn_metadata(
]
num_computed_tokens_cpu
=
torch
.
tensor
(
context_lens
,
dtype
=
torch
.
int32
)
# Create block table
(r
and
om for test
ing
)
# Create block table and
slot mapp
ing
max_blocks
=
(
max
(
batch_spec
.
seq_lens
)
+
block_size
-
1
)
//
block_size
block_table_tensor
=
torch
.
randint
(
0
,
max_block_idx
,
(
batch_spec
.
batch_size
,
max_blocks
),
dtype
=
torch
.
int32
,
device
=
device
)
# Create slot mapping
slot_mapping
=
torch
.
randint
(
0
,
max_block_idx
,
(
num_tokens
,
),
dtype
=
torch
.
int64
,
device
=
device
)
if
arange_block_indices
:
num_blocks
=
batch_spec
.
batch_size
*
max_blocks
block_table_tensor
=
torch
.
arange
(
num_blocks
,
dtype
=
torch
.
int32
,
device
=
device
).
view
(
batch_spec
.
batch_size
,
max_blocks
)
slot_mapping
=
torch
.
arange
(
num_tokens
,
dtype
=
torch
.
int64
,
device
=
device
).
view
(
num_tokens
)
else
:
block_table_tensor
=
torch
.
randint
(
0
,
max_block_idx
,
(
batch_spec
.
batch_size
,
max_blocks
),
dtype
=
torch
.
int32
,
device
=
device
)
slot_mapping
=
torch
.
randint
(
0
,
max_block_idx
,
(
num_tokens
,
),
dtype
=
torch
.
int64
,
device
=
device
)
# Calculate max query length
max_query_len
=
max
(
batch_spec
.
query_lens
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment