Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0edaf752
Unverified
Commit
0edaf752
authored
Aug 01, 2025
by
Sage Moore
Committed by
GitHub
Aug 01, 2025
Browse files
[Attention][DBO] Add support for "splitting" the CommonAttentionMetadata (#21153)
Signed-off-by:
Sage Moore
<
sage@neuralmagic.com
>
parent
6e8d8c4a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
240 additions
and
0 deletions
+240
-0
tests/v1/attention/test_attention_splitting.py
tests/v1/attention/test_attention_splitting.py
+157
-0
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+83
-0
No files found.
tests/v1/attention/test_attention_splitting.py
0 → 100644
View file @
0edaf752
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
tests.v1.attention.test_attention_backends
import
BATCH_SPECS
from
tests.v1.attention.utils
import
create_common_attn_metadata
from
vllm.v1.attention.backends.utils
import
(
UbatchSlice
,
_make_metadata_with_slice
,
slice_query_start_locs
,
split_attn_metadata
)
@pytest.fixture
def sample_query_start_loc():
    """Cumulative query offsets for five requests (lengths 5, 7, 8, 15, 15)."""
    offsets = [0, 5, 12, 20, 35, 50]
    return torch.tensor(offsets)
def test_basic_slice_middle(sample_query_start_loc):
    """Slicing requests [1, 3) re-bases their offsets so they start at 0."""
    # Offsets [5, 12, 20] shifted down by the first selected offset (5).
    rebased = slice_query_start_locs(sample_query_start_loc, slice(1, 3))
    assert torch.equal(rebased, torch.tensor([0, 7, 15]))
def test_slice_from_beginning(sample_query_start_loc):
    """Slicing requests [0, 2) keeps the leading offsets unchanged."""
    # The first selected offset is already 0, so nothing is shifted.
    rebased = slice_query_start_locs(sample_query_start_loc, slice(0, 2))
    assert torch.equal(rebased, torch.tensor([0, 5, 12]))
def test_slice_to_end(sample_query_start_loc):
    """Slicing requests [3, 5) covers the tail of the tensor."""
    # Offsets [20, 35, 50] shifted down by 20; stop == index of last entry.
    rebased = slice_query_start_locs(sample_query_start_loc, slice(3, 5))
    assert torch.equal(rebased, torch.tensor([0, 15, 30]))
def test_single_element_slice(sample_query_start_loc):
    """Slicing a single request [2, 3) yields a two-entry offset tensor."""
    # Offsets [12, 20] shifted down by 12: one request of length 8.
    rebased = slice_query_start_locs(sample_query_start_loc, slice(2, 3))
    assert torch.equal(rebased, torch.tensor([0, 8]))
def test_full_tensor_slice(sample_query_start_loc):
    """Slicing all five requests [0, 5) reproduces the original offsets."""
    rebased = slice_query_start_locs(sample_query_start_loc, slice(0, 5))
    unchanged = torch.tensor([0, 5, 12, 20, 35, 50])
    assert torch.equal(rebased, unchanged)
def test_slice_bounds_edge_cases(sample_query_start_loc):
    """Slicing only the last request [4, 5) ends exactly on the final offset."""
    # Offsets [35, 50] shifted down by 35.
    last_request = slice(4, 5)
    rebased = slice_query_start_locs(sample_query_start_loc, last_request)
    assert torch.equal(rebased, torch.tensor([0, 15]))
@pytest.fixture
def small_decode_metadata():
    """Build CPU attention metadata from the "small_decode" batch spec."""
    return create_common_attn_metadata(
        BATCH_SPECS["small_decode"],
        block_size=16,
        device=torch.device("cpu"),
    )
@pytest.fixture
def large_decode_metadata():
    """Create metadata for large decode batch"""
    # Uses the "large_decode" entry of BATCH_SPECS, built on CPU with a
    # block size of 16 (same settings as the other metadata fixtures).
    batch_spec = BATCH_SPECS["large_decode"]
    device = torch.device("cpu")
    return create_common_attn_metadata(batch_spec,
                                       block_size=16,
                                       device=device)
@pytest.fixture
def mixed_small_metadata():
    """Build CPU attention metadata from the "mixed_small" batch spec."""
    return create_common_attn_metadata(
        BATCH_SPECS["mixed_small"],
        block_size=16,
        device=torch.device("cpu"),
    )
# Tests for _make_metadata_with_slice
def test_make_metadata_with_slice_decode_batch(small_decode_metadata):
    """Slice the first request (and its single token) out of a decode batch."""
    first_request_only = UbatchSlice(slice(0, 1), slice(0, 1))

    sliced = _make_metadata_with_slice(first_request_only,
                                       small_decode_metadata)

    # slice(0, 1) selects exactly one request and one token.
    assert sliced.num_reqs == 1
    assert sliced.num_actual_tokens == 1
    assert sliced.max_query_len == 1

    expected_query_start_loc = torch.tensor([0, 1])
    assert torch.equal(sliced.query_start_loc, expected_query_start_loc)
    expected_seq_lens = torch.tensor([32])
    assert torch.equal(sliced.seq_lens, expected_seq_lens)
def test_make_metadata_with_slice_mixed_batch(mixed_small_metadata):
    """Slice two requests out of the middle of a mixed batch."""
    # Half-open ranges: requests [1, 3) and tokens [1, 7).
    request_range = slice(1, 3)
    token_range = slice(1, 7)

    sliced = _make_metadata_with_slice(
        UbatchSlice(request_range, token_range), mixed_small_metadata)

    assert sliced.num_reqs == 2  # two requests selected
    assert sliced.num_actual_tokens == 6  # six tokens selected
    assert sliced.max_query_len == 5

    expected_query_start_loc = torch.tensor([0, 1, 6])
    assert torch.equal(sliced.query_start_loc, expected_query_start_loc)
    expected_seq_lens = torch.tensor([40, 48])
    assert torch.equal(sliced.seq_lens, expected_seq_lens)
def test_split_attn_metadata_decode_batch(large_decode_metadata):
    """Split a decode batch at its midpoint and check both halves."""
    total = large_decode_metadata.num_reqs
    half = total // 2

    # The same bounds are used for the request slice and the token slice.
    bounds = [(0, half), (half, total)]
    ubatch_slices = [
        UbatchSlice(slice(lo, hi), slice(lo, hi)) for lo, hi in bounds
    ]

    results = split_attn_metadata(ubatch_slices, large_decode_metadata)

    assert len(results) == 2

    # Both halves are checked against the same expectations, first then
    # second, matching the order of ubatch_slices.
    for part in (results[0], results[1]):
        assert part.num_reqs == half
        assert part.num_actual_tokens == half
        assert torch.equal(part.seq_lens, torch.tensor([2048] * half))
vllm/v1/attention/backends/utils.py
View file @
0edaf752
...
@@ -63,6 +63,89 @@ class CommonAttentionMetadata:
...
@@ -63,6 +63,89 @@ class CommonAttentionMetadata:
causal
:
bool
=
True
causal
:
bool
=
True
@dataclass
class UbatchSlice:
    """Pair of slices selecting a sub-range (ubatch) of a batch.

    Consumed by _make_metadata_with_slice to carve a smaller
    CommonAttentionMetadata out of an existing one.
    """
    # Range of requests; used to index the per-request fields
    # (query_start_loc, seq_lens, num_computed_tokens_cpu,
    # block_table_tensor). Its start/stop are read directly.
    request_slice: slice
    # Range of tokens; used to index slot_mapping, and its start/stop
    # determine num_actual_tokens.
    token_slice: slice
def slice_query_start_locs(
    query_start_loc: torch.Tensor,
    request_slice: slice,
) -> torch.Tensor:
    """
    Build the query_start_loc for the requests selected by request_slice.

    The entries from request_slice.start through request_slice.stop
    (inclusive) are taken and shifted so that the first entry becomes 0.

    Note: the result is a newly created tensor, which will break cudagraph
    compatibility.
    """
    first = request_slice.start
    # One extra entry past `stop` so the end offset of the last selected
    # request is included.
    window = query_start_loc[first:request_slice.stop + 1]
    base = query_start_loc[first]
    return window - base
def _make_metadata_with_slice(
        ubatch_slice: UbatchSlice,
        attn_metadata: CommonAttentionMetadata) -> CommonAttentionMetadata:
    """
    Create a new CommonAttentionMetadata that corresponds to the requests
    (and tokens) included in ubatch_slice.

    Per-request fields (query_start_loc, seq_lens, num_computed_tokens_cpu,
    block_table_tensor) are indexed with ubatch_slice.request_slice, while
    slot_mapping is indexed with ubatch_slice.token_slice. The fields of
    attn_metadata are only read here; nothing is assigned back to it.

    Args:
        ubatch_slice: request/token ranges to extract. Both slices must
            have explicit integer start and stop values.
        attn_metadata: the metadata of the full batch.

    Returns:
        A CommonAttentionMetadata describing only the selected range, with
        query_start_loc re-based to start at 0.

    Raises:
        AssertionError: if the sliced query_start_loc has fewer than two
            entries, i.e. the request slice selects no request.
    """
    request_slice = ubatch_slice.request_slice
    token_slice = ubatch_slice.token_slice

    query_start_loc = slice_query_start_locs(attn_metadata.query_start_loc,
                                             request_slice)
    # At least one request needs a [start, end] pair of offsets. The length
    # must be taken from the tensor itself: len(query_start_loc >= 2) would
    # measure a boolean tensor and pass for any non-empty input.
    assert len(query_start_loc) >= 2, (
        "query_start_loc must have at least 2 elements, "
        f"got {len(query_start_loc)}")
    query_start_loc_cpu = slice_query_start_locs(
        attn_metadata.query_start_loc_cpu, request_slice)

    seq_lens = attn_metadata.seq_lens[request_slice]
    seq_lens_cpu = attn_metadata.seq_lens_cpu[request_slice]
    num_computed_tokens_cpu = attn_metadata.num_computed_tokens_cpu[
        request_slice]

    num_requests = request_slice.stop - request_slice.start
    num_actual_tokens = token_slice.stop - token_slice.start
    # Longest query in the slice: the largest gap between consecutive
    # start offsets.
    max_query_len = int(
        torch.max(
            torch.abs(query_start_loc_cpu[1:] -
                      query_start_loc_cpu[:-1])).item())

    block_table_tensor = attn_metadata.block_table_tensor[request_slice]
    slot_mapping = attn_metadata.slot_mapping[token_slice]

    return CommonAttentionMetadata(
        query_start_loc=query_start_loc,
        query_start_loc_cpu=query_start_loc_cpu,
        seq_lens=seq_lens,
        seq_lens_cpu=seq_lens_cpu,
        num_computed_tokens_cpu=num_computed_tokens_cpu,
        num_reqs=num_requests,
        num_actual_tokens=num_actual_tokens,
        max_query_len=max_query_len,
        block_table_tensor=block_table_tensor,
        slot_mapping=slot_mapping,
    )
def split_attn_metadata(
    ubatch_slices: list[UbatchSlice],
    common_attn_metadata: CommonAttentionMetadata,
) -> list[CommonAttentionMetadata]:
    """
    Build one CommonAttentionMetadata per UbatchSlice in ubatch_slices,
    each covering only the requests of its slice. The returned list follows
    the order of ubatch_slices.

    Note: This function does not modify common_attn_metadata
    """
    return [
        _make_metadata_with_slice(piece, common_attn_metadata)
        for piece in ubatch_slices
    ]
M
=
TypeVar
(
"M"
)
M
=
TypeVar
(
"M"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment