Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
98f60e5a
Unverified
Commit
98f60e5a
authored
Jan 13, 2026
by
Matthew Bonanni
Committed by
GitHub
Jan 13, 2026
Browse files
[6/N][Attention] Move utils to more appropriate locations (#32215)
Signed-off-by:
Matthew Bonanni
<
mbonanni@redhat.com
>
parent
fefce498
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
168 additions
and
178 deletions
+168
-178
tests/v1/attention/test_attention_splitting.py
tests/v1/attention/test_attention_splitting.py
+4
-2
vllm/model_executor/layers/attention/chunked_local_attention.py
...odel_executor/layers/attention/chunked_local_attention.py
+1
-1
vllm/model_executor/layers/attention/cross_attention.py
vllm/model_executor/layers/attention/cross_attention.py
+0
-2
vllm/model_executor/layers/attention/encoder_only_attention.py
...model_executor/layers/attention/encoder_only_attention.py
+0
-2
vllm/model_executor/layers/attention/static_sink_attention.py
.../model_executor/layers/attention/static_sink_attention.py
+0
-2
vllm/model_executor/models/whisper_utils.py
vllm/model_executor/models/whisper_utils.py
+1
-3
vllm/v1/attention/backend.py
vllm/v1/attention/backend.py
+25
-1
vllm/v1/attention/backends/utils.py
vllm/v1/attention/backends/utils.py
+1
-159
vllm/v1/worker/gpu/cudagraph_utils.py
vllm/v1/worker/gpu/cudagraph_utils.py
+1
-1
vllm/v1/worker/gpu/spec_decode/eagle.py
vllm/v1/worker/gpu/spec_decode/eagle.py
+1
-1
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
vllm/v1/worker/ubatch_utils.py
vllm/v1/worker/ubatch_utils.py
+131
-0
vllm/v1/worker/utils.py
vllm/v1/worker/utils.py
+1
-2
No files found.
tests/v1/attention/test_attention_splitting.py
View file @
98f60e5a
...
@@ -7,13 +7,15 @@ import torch
...
@@ -7,13 +7,15 @@ import torch
from
tests.v1.attention.test_attention_backends
import
BATCH_SPECS
from
tests.v1.attention.test_attention_backends
import
BATCH_SPECS
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
tests.v1.attention.utils
import
BatchSpec
,
create_common_attn_metadata
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
split_decodes_and_prefills
,
)
from
vllm.v1.worker.ubatch_utils
import
(
UBatchSlice
,
UBatchSlice
,
_make_metadata_with_slice
,
_make_metadata_with_slice
,
maybe_create_ubatch_slices
,
slice_query_start_locs
,
slice_query_start_locs
,
split_attn_metadata
,
split_attn_metadata
,
split_decodes_and_prefills
,
)
)
from
vllm.v1.worker.ubatch_utils
import
maybe_create_ubatch_slices
@
pytest
.
fixture
@
pytest
.
fixture
...
...
vllm/model_executor/layers/attention/chunked_local_attention.py
View file @
98f60e5a
...
@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
...
@@ -13,10 +13,10 @@ from vllm.v1.attention.backend import (
AttentionCGSupport
,
AttentionCGSupport
,
AttentionMetadataBuilder
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
subclass_attention_backend
,
)
)
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
make_local_attention_virtual_batches
,
make_local_attention_virtual_batches
,
subclass_attention_backend
,
)
)
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
...
...
vllm/model_executor/layers/attention/cross_attention.py
View file @
98f60e5a
...
@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
...
@@ -15,8 +15,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadata
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.utils
import
(
subclass_attention_backend
,
subclass_attention_backend
,
)
)
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.attention.selector
import
get_attn_backend
...
...
vllm/model_executor/layers/attention/encoder_only_attention.py
View file @
98f60e5a
...
@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
...
@@ -13,8 +13,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadata
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.utils
import
(
subclass_attention_backend
,
subclass_attention_backend
,
)
)
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.attention.selector
import
get_attn_backend
...
...
vllm/model_executor/layers/attention/static_sink_attention.py
View file @
98f60e5a
...
@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
...
@@ -16,8 +16,6 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadata
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.utils
import
(
subclass_attention_backend
,
subclass_attention_backend
,
)
)
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
from
vllm.v1.attention.ops.triton_reshape_and_cache_flash
import
(
...
...
vllm/model_executor/models/whisper_utils.py
View file @
98f60e5a
...
@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
...
@@ -17,11 +17,9 @@ from vllm.v1.attention.backend import (
AttentionMetadata
,
AttentionMetadata
,
AttentionType
,
AttentionType
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.backends.utils
import
(
subclass_attention_backend_with_overrides
,
subclass_attention_backend_with_overrides
,
)
)
from
vllm.v1.attention.backends.flash_attn
import
FlashAttentionBackend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.attention.selector
import
get_attn_backend
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
...
...
vllm/v1/attention/backend.py
View file @
98f60e5a
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
TYPE_CHECKING
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Generic
,
Protocol
,
TypeVar
,
get_args
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
...
@@ -734,3 +734,27 @@ class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
def
is_quantized_kv_cache
(
kv_cache_dtype
:
str
)
->
bool
:
return
kv_cache_dtype
!=
"auto"
return
kv_cache_dtype
!=
"auto"
def
subclass_attention_backend
(
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
builder_cls
:
type
[
AttentionMetadataBuilder
[
M
]],
)
->
type
[
AttentionBackend
]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name
:
str
=
name_prefix
+
attention_backend_cls
.
__name__
# type: ignore
return
type
(
name
,
(
attention_backend_cls
,),
{
"get_builder_cls"
:
lambda
:
builder_cls
}
)
def
subclass_attention_backend_with_overrides
(
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
overrides
:
dict
[
str
,
Any
],
)
->
type
[
AttentionBackend
]:
name
:
str
=
name_prefix
+
attention_backend_cls
.
__name__
# type: ignore
return
type
(
name
,
(
attention_backend_cls
,),
overrides
)
vllm/v1/attention/backends/utils.py
View file @
98f60e5a
...
@@ -8,7 +8,6 @@ from typing import (
...
@@ -8,7 +8,6 @@ from typing import (
Any
,
Any
,
Literal
,
Literal
,
Protocol
,
Protocol
,
TypeVar
,
get_args
,
get_args
,
)
)
...
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
...
@@ -33,10 +32,9 @@ from vllm.v1.attention.backend import (
AttentionBackend
,
AttentionBackend
,
AttentionImpl
,
AttentionImpl
,
AttentionMetadata
,
AttentionMetadata
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
CommonAttentionMetadata
,
subclass_attention_backend
,
)
)
from
vllm.v1.worker.ubatch_utils
import
UBatchSlice
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
KVCacheLayoutType
=
Literal
[
"NHD"
,
"HND"
]
KVCacheLayoutType
=
Literal
[
"NHD"
,
"HND"
]
...
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
...
@@ -49,135 +47,6 @@ def is_valid_kv_cache_layout(value: str) -> bool:
return
value
in
get_args
(
KVCacheLayoutType
)
return
value
in
get_args
(
KVCacheLayoutType
)
def
slice_query_start_locs
(
query_start_loc
:
torch
.
Tensor
,
request_slice
:
slice
,
)
->
torch
.
Tensor
:
"""
Creates a new query_start_loc that corresponds to the requests in
request_slice.
Note: This function creates a new tensor to hold the new query_start_locs.
This will break cudagraph compatibility.
"""
return
(
query_start_loc
[
request_slice
.
start
:
request_slice
.
stop
+
1
]
-
query_start_loc
[
request_slice
.
start
]
)
def
_make_metadata_with_slice
(
ubatch_slice
:
UBatchSlice
,
attn_metadata
:
CommonAttentionMetadata
)
->
CommonAttentionMetadata
:
"""
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""
assert
not
ubatch_slice
.
is_empty
(),
f
"Ubatch slice
{
ubatch_slice
}
is empty"
request_slice
=
ubatch_slice
.
request_slice
token_slice
=
ubatch_slice
.
token_slice
start_locs
=
attn_metadata
.
query_start_loc_cpu
first_req
=
request_slice
.
start
first_tok
=
token_slice
.
start
last_req
=
request_slice
.
stop
-
1
last_tok
=
token_slice
.
stop
-
1
assert
start_locs
[
first_req
]
<=
first_tok
<
start_locs
[
first_req
+
1
],
(
"Token slice start outside of first request"
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request
=
first_tok
>
start_locs
[
first_req
]
splits_last_request
=
last_tok
<
start_locs
[
last_req
+
1
]
-
1
query_start_loc_cpu
=
slice_query_start_locs
(
start_locs
,
request_slice
)
query_start_loc
=
slice_query_start_locs
(
attn_metadata
.
query_start_loc
,
request_slice
)
assert
len
(
query_start_loc
)
>=
2
,
(
f
"query_start_loc must have at least 2 elements, got
{
len
(
query_start_loc
)
}
"
)
if
splits_first_request
:
tokens_skipped
=
first_tok
-
start_locs
[
first_req
]
query_start_loc
[
1
:]
-=
tokens_skipped
query_start_loc_cpu
[
1
:]
-=
tokens_skipped
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
if
splits_last_request
:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped
=
start_locs
[
last_req
+
1
]
-
token_slice
.
stop
query_start_loc
[
-
1
]
-=
tokens_skipped
query_start_loc_cpu
[
-
1
]
-=
tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens
=
seq_lens
.
clone
()
seq_lens_cpu
=
seq_lens_cpu
.
clone
()
seq_lens
[
-
1
]
-=
tokens_skipped
seq_lens_cpu
[
-
1
]
-=
tokens_skipped
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
num_computed_tokens_cpu
=
attn_metadata
.
num_computed_tokens_cpu
[
request_slice
]
num_requests
=
request_slice
.
stop
-
request_slice
.
start
num_actual_tokens
=
token_slice
.
stop
-
token_slice
.
start
max_query_len
=
int
(
torch
.
max
(
torch
.
abs
(
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
])).
item
()
)
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if
max_query_len
==
0
:
max_query_len
=
attn_metadata
.
max_query_len
block_table_tensor
=
attn_metadata
.
block_table_tensor
[
request_slice
]
slot_mapping
=
attn_metadata
.
slot_mapping
[
token_slice
]
return
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
seq_lens
,
num_reqs
=
num_requests
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
)
def
split_attn_metadata
(
ubatch_slices
:
list
[
UBatchSlice
],
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
list
[
CommonAttentionMetadata
]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""
results
=
[]
for
ubatch_slice
in
ubatch_slices
:
results
.
append
(
_make_metadata_with_slice
(
ubatch_slice
,
common_attn_metadata
))
return
results
@
functools
.
lru_cache
@
functools
.
lru_cache
def
get_kv_cache_layout
():
def
get_kv_cache_layout
():
# Format specified by the code.
# Format specified by the code.
...
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
...
@@ -548,33 +417,6 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
return
common_attn_metadata
return
common_attn_metadata
M
=
TypeVar
(
"M"
)
def
subclass_attention_backend
(
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
builder_cls
:
type
[
AttentionMetadataBuilder
[
M
]],
)
->
type
[
AttentionBackend
]:
"""
Return a new subclass where `get_builder_cls` returns `builder_cls`.
"""
name
:
str
=
name_prefix
+
attention_backend_cls
.
__name__
# type: ignore
return
type
(
name
,
(
attention_backend_cls
,),
{
"get_builder_cls"
:
lambda
:
builder_cls
}
)
def
subclass_attention_backend_with_overrides
(
name_prefix
:
str
,
attention_backend_cls
:
type
[
AttentionBackend
],
overrides
:
dict
[
str
,
Any
],
)
->
type
[
AttentionBackend
]:
name
:
str
=
name_prefix
+
attention_backend_cls
.
__name__
# type: ignore
return
type
(
name
,
(
attention_backend_cls
,),
overrides
)
def
split_decodes_prefills_and_extends
(
def
split_decodes_prefills_and_extends
(
common_attn_metadata
:
CommonAttentionMetadata
,
common_attn_metadata
:
CommonAttentionMetadata
,
decode_threshold
:
int
=
1
,
decode_threshold
:
int
=
1
,
...
...
vllm/v1/worker/gpu/cudagraph_utils.py
View file @
98f60e5a
...
@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
...
@@ -12,7 +12,7 @@ from vllm.config import VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.distributed.parallel_state
import
graph_capture
,
is_global_first_rank
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.v1.attention.backend
s.utils
import
AttentionMetadataBuilder
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
...
...
vllm/v1/worker/gpu/spec_decode/eagle.py
View file @
98f60e5a
...
@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
...
@@ -11,7 +11,7 @@ from vllm.forward_context import set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.triton_utils
import
tl
,
triton
from
vllm.triton_utils
import
tl
,
triton
from
vllm.v1.attention.backend
s.utils
import
AttentionMetadataBuilder
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.attn_utils
import
build_attn_metadata
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
...
...
vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py
View file @
98f60e5a
...
@@ -6,7 +6,7 @@ import torch
...
@@ -6,7 +6,7 @@ import torch
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.config.compilation
import
CUDAGraphMode
from
vllm.v1.attention.backend
s.utils
import
AttentionMetadataBuilder
from
vllm.v1.attention.backend
import
AttentionMetadataBuilder
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.block_table
import
BlockTables
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
from
vllm.v1.worker.gpu.cudagraph_utils
import
(
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
98f60e5a
...
@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
...
@@ -112,7 +112,6 @@ from vllm.v1.attention.backends.utils import (
create_fast_prefill_custom_backend
,
create_fast_prefill_custom_backend
,
get_dcp_local_seq_lens
,
get_dcp_local_seq_lens
,
reorder_batch_to_split_decodes_and_prefills
,
reorder_batch_to_split_decodes_and_prefills
,
split_attn_metadata
,
)
)
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.cudagraph_dispatcher
import
CudagraphDispatcher
from
vllm.v1.kv_cache_interface
import
(
from
vllm.v1.kv_cache_interface
import
(
...
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
...
@@ -165,6 +164,7 @@ from vllm.v1.worker.ubatch_utils import (
UBatchSlices
,
UBatchSlices
,
check_ubatch_thresholds
,
check_ubatch_thresholds
,
maybe_create_ubatch_slices
,
maybe_create_ubatch_slices
,
split_attn_metadata
,
)
)
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.utils
import
is_residual_scattered_for_sp
from
vllm.v1.worker.workspace
import
lock_workspace
from
vllm.v1.worker.workspace
import
lock_workspace
...
...
vllm/v1/worker/ubatch_utils.py
View file @
98f60e5a
...
@@ -4,8 +4,10 @@ from dataclasses import dataclass
...
@@ -4,8 +4,10 @@ from dataclasses import dataclass
from
typing
import
TypeAlias
from
typing
import
TypeAlias
import
numpy
as
np
import
numpy
as
np
import
torch
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.v1.attention.backend
import
CommonAttentionMetadata
@
dataclass
@
dataclass
...
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
...
@@ -110,3 +112,132 @@ def maybe_create_ubatch_slices(
assert
sum
(
s
.
num_tokens
for
s
in
ubatch_slices_padded
)
==
num_tokens_padded
assert
sum
(
s
.
num_tokens
for
s
in
ubatch_slices_padded
)
==
num_tokens_padded
return
ubatch_slices
,
ubatch_slices_padded
return
ubatch_slices
,
ubatch_slices_padded
def
slice_query_start_locs
(
query_start_loc
:
torch
.
Tensor
,
request_slice
:
slice
,
)
->
torch
.
Tensor
:
"""
Creates a new query_start_loc that corresponds to the requests in
request_slice.
Note: This function creates a new tensor to hold the new query_start_locs.
This will break cudagraph compatibility.
"""
return
(
query_start_loc
[
request_slice
.
start
:
request_slice
.
stop
+
1
]
-
query_start_loc
[
request_slice
.
start
]
)
def
_make_metadata_with_slice
(
ubatch_slice
:
UBatchSlice
,
attn_metadata
:
CommonAttentionMetadata
)
->
CommonAttentionMetadata
:
"""
This function creates a new CommonAttentionMetadata that corresponds to
the requests included in ubatch_slice
"""
assert
not
ubatch_slice
.
is_empty
(),
f
"Ubatch slice
{
ubatch_slice
}
is empty"
request_slice
=
ubatch_slice
.
request_slice
token_slice
=
ubatch_slice
.
token_slice
start_locs
=
attn_metadata
.
query_start_loc_cpu
first_req
=
request_slice
.
start
first_tok
=
token_slice
.
start
last_req
=
request_slice
.
stop
-
1
last_tok
=
token_slice
.
stop
-
1
assert
start_locs
[
first_req
]
<=
first_tok
<
start_locs
[
first_req
+
1
],
(
"Token slice start outside of first request"
)
# NOTE: last token can be outside of the last request if we have CG padding.
# If the request is split across ubatches, we have to adjust the metadata.
# splits_first_request: The first request in this slice is the continuation of
# a request that started in a previous slice.
# splits_last_request: The last request in this slice continues into the
# next slice.
splits_first_request
=
first_tok
>
start_locs
[
first_req
]
splits_last_request
=
last_tok
<
start_locs
[
last_req
+
1
]
-
1
query_start_loc_cpu
=
slice_query_start_locs
(
start_locs
,
request_slice
)
query_start_loc
=
slice_query_start_locs
(
attn_metadata
.
query_start_loc
,
request_slice
)
assert
len
(
query_start_loc
)
>=
2
,
(
f
"query_start_loc must have at least 2 elements, got
{
len
(
query_start_loc
)
}
"
)
if
splits_first_request
:
tokens_skipped
=
first_tok
-
start_locs
[
first_req
]
query_start_loc
[
1
:]
-=
tokens_skipped
query_start_loc_cpu
[
1
:]
-=
tokens_skipped
seq_lens
=
attn_metadata
.
seq_lens
[
request_slice
]
seq_lens_cpu
=
attn_metadata
.
seq_lens_cpu
[
request_slice
]
if
splits_last_request
:
# NOTE: We use start_locs (the original query_start_loc_cpu) to calculate
# the tokens skipped because query_start_loc_cpu might have been modified
# if splits_first_request is True.
tokens_skipped
=
start_locs
[
last_req
+
1
]
-
token_slice
.
stop
query_start_loc
[
-
1
]
-=
tokens_skipped
query_start_loc_cpu
[
-
1
]
-=
tokens_skipped
# Make sure we don't modify the seq_lens tensors
# (not cudagraph compatible)
seq_lens
=
seq_lens
.
clone
()
seq_lens_cpu
=
seq_lens_cpu
.
clone
()
seq_lens
[
-
1
]
-=
tokens_skipped
seq_lens_cpu
[
-
1
]
-=
tokens_skipped
max_seq_len
=
int
(
seq_lens_cpu
.
max
())
num_computed_tokens_cpu
=
attn_metadata
.
num_computed_tokens_cpu
[
request_slice
]
num_requests
=
request_slice
.
stop
-
request_slice
.
start
num_actual_tokens
=
token_slice
.
stop
-
token_slice
.
start
max_query_len
=
int
(
torch
.
max
(
torch
.
abs
(
query_start_loc_cpu
[
1
:]
-
query_start_loc_cpu
[:
-
1
])).
item
()
)
# This is to account for the case where we are in a dummy
# run and query_start_loc_cpu is full of 0s
if
max_query_len
==
0
:
max_query_len
=
attn_metadata
.
max_query_len
block_table_tensor
=
attn_metadata
.
block_table_tensor
[
request_slice
]
slot_mapping
=
attn_metadata
.
slot_mapping
[
token_slice
]
return
CommonAttentionMetadata
(
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
seq_lens
=
seq_lens
,
num_reqs
=
num_requests
,
num_actual_tokens
=
num_actual_tokens
,
max_query_len
=
max_query_len
,
max_seq_len
=
max_seq_len
,
block_table_tensor
=
block_table_tensor
,
slot_mapping
=
slot_mapping
,
_seq_lens_cpu
=
seq_lens_cpu
,
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
)
def
split_attn_metadata
(
ubatch_slices
:
list
[
UBatchSlice
],
common_attn_metadata
:
CommonAttentionMetadata
,
)
->
list
[
CommonAttentionMetadata
]:
"""
Creates a new CommonAttentionMetadata instance that corresponds to the
requests for each UBatchSlice in ubatch_slices.
Note: This function does not modify common_attn_metadata
"""
results
=
[]
for
ubatch_slice
in
ubatch_slices
:
results
.
append
(
_make_metadata_with_slice
(
ubatch_slice
,
common_attn_metadata
))
return
results
vllm/v1/worker/utils.py
View file @
98f60e5a
...
@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
...
@@ -16,8 +16,7 @@ from vllm.multimodal.cache import processor_only_cache_from_config
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.multimodal.registry
import
MultiModalRegistry
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils.mem_utils
import
MemorySnapshot
,
format_gib
from
vllm.utils.mem_utils
import
MemorySnapshot
,
format_gib
from
vllm.v1.attention.backend
import
AttentionBackend
from
vllm.v1.attention.backend
import
AttentionBackend
,
AttentionMetadataBuilder
from
vllm.v1.attention.backends.utils
import
AttentionMetadataBuilder
from
vllm.v1.core.encoder_cache_manager
import
compute_mm_encoder_budget
from
vllm.v1.core.encoder_cache_manager
import
compute_mm_encoder_budget
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
,
KVCacheSpec
from
vllm.v1.kv_cache_interface
import
KVCacheGroupSpec
,
KVCacheSpec
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment