Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
042da732
Unverified
Commit
042da732
authored
Dec 11, 2025
by
Lucas Wilkinson
Committed by
GitHub
Dec 11, 2025
Browse files
[Core] Refactor `_build_attention_metadata` (#29628)
Signed-off-by:
Lucas Wilkinson
<
lwilkins@redhat.com
>
parent
b5945d49
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
124 additions
and
126 deletions
+124
-126
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+124
-126
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
042da732
...
@@ -1534,28 +1534,13 @@ class GPUModelRunner(
...
@@ -1534,28 +1534,13 @@ class GPUModelRunner(
"""
"""
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
:return: tuple[attn_metadata, spec_decode_common_attn_metadata]
"""
"""
# Attention metadata is not needed for attention free models
if
len
(
self
.
kv_cache_config
.
kv_cache_groups
)
==
0
:
return
{},
None
num_tokens_padded
=
num_tokens_padded
or
num_tokens
num_tokens_padded
=
num_tokens_padded
or
num_tokens
num_reqs_padded
=
num_reqs_padded
or
num_reqs
num_reqs_padded
=
num_reqs_padded
or
num_reqs
assert
num_reqs_padded
is
not
None
and
num_tokens_padded
is
not
None
logits_indices_padded
=
None
num_logits_indices
=
None
if
logits_indices
is
not
None
:
num_logits_indices
=
logits_indices
.
size
(
0
)
if
self
.
cache_config
.
kv_sharing_fast_prefill
:
logits_indices_padded
=
self
.
_prepare_kv_sharing_fast_prefill
(
logits_indices
)
# update seq_lens of decode reqs under DCP.
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
self
.
seq_lens
.
cpu
[:
num_reqs
],
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
parallel_config
.
cp_kv_cache_interleave_size
,
)
self
.
dcp_local_seq_lens
.
cpu
[
num_reqs
:].
fill_
(
0
)
self
.
dcp_local_seq_lens
.
copy_to_gpu
(
num_reqs_padded
)
attn_metadata
:
PerLayerAttnMetadata
=
{}
attn_metadata
:
PerLayerAttnMetadata
=
{}
if
ubatch_slices
is
not
None
:
if
ubatch_slices
is
not
None
:
...
@@ -1576,36 +1561,12 @@ class GPUModelRunner(
...
@@ -1576,36 +1561,12 @@ class GPUModelRunner(
self
.
num_accepted_tokens
.
np
[
num_reqs
:].
fill
(
1
)
self
.
num_accepted_tokens
.
np
[
num_reqs
:].
fill
(
1
)
self
.
num_accepted_tokens
.
copy_to_gpu
()
self
.
num_accepted_tokens
.
copy_to_gpu
()
# Used in the below loop, uses padded shapes
kv_cache_groups
=
self
.
kv_cache_config
.
kv_cache_groups
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
]
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
]
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
]
seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
]
num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
:
num_reqs_padded
]
dcp_local_seq_lens
,
dcp_local_seq_lens_cpu
=
None
,
None
if
self
.
dcp_world_size
>
1
:
dcp_local_seq_lens
=
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs_padded
]
dcp_local_seq_lens_cpu
=
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs_padded
]
spec_decode_common_attn_metadata
=
None
def
_get_block_table_and_slot_mapping
(
kv_cache_gid
:
int
):
assert
num_reqs_padded
is
not
None
and
num_tokens_padded
is
not
None
# Prepare the attention metadata for each KV cache group and make layers
kv_cache_spec
=
kv_cache_groups
[
kv_cache_gid
].
kv_cache_spec
# in the same group share the same metadata.
if
isinstance
(
kv_cache_spec
,
EncoderOnlyAttentionSpec
):
for
kv_cache_gid
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
encoder_seq_lens
,
encoder_seq_lens_cpu
=
self
.
_get_encoder_seq_lens
(
num_scheduled_tokens
or
{},
kv_cache_group
.
kv_cache_spec
,
num_reqs_padded
,
)
if
isinstance
(
kv_cache_group
.
kv_cache_spec
,
EncoderOnlyAttentionSpec
):
# Encoder-only layers do not have KV cache, so we need to
# create a dummy block table and slot mapping for them.
blk_table_tensor
=
torch
.
zeros
(
blk_table_tensor
=
torch
.
zeros
(
(
num_reqs_padded
,
1
),
(
num_reqs_padded
,
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
...
@@ -1621,92 +1582,129 @@ class GPUModelRunner(
...
@@ -1621,92 +1582,129 @@ class GPUModelRunner(
blk_table_tensor
=
blk_table
.
get_device_tensor
(
num_reqs_padded
)
blk_table_tensor
=
blk_table
.
get_device_tensor
(
num_reqs_padded
)
slot_mapping
=
blk_table
.
slot_mapping
.
gpu
[:
num_tokens_padded
]
slot_mapping
=
blk_table
.
slot_mapping
.
gpu
[:
num_tokens_padded
]
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# Fill unused with -1. Needed for reshape_and_cache in full cuda
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
# graph mode. `blk_table_tensor` -1 to match mamba PAD_SLOT_ID
slot_mapping
[
num_tokens
:
num_tokens_padded
].
fill_
(
-
1
)
slot_mapping
[
num_tokens
:
num_tokens_padded
].
fill_
(
-
1
)
blk_table_tensor
[
num_reqs
:
num_reqs_padded
].
fill_
(
-
1
)
blk_table_tensor
[
num_reqs
:
num_reqs_padded
].
fill_
(
-
1
)
common_attn_metadata
=
CommonAttentionMetadata
(
return
blk_table_tensor
,
slot_mapping
query_start_loc
=
query_start_loc
,
query_start_loc_cpu
=
query_start_loc_cpu
,
block_table_gid_0
,
slot_mapping_gid_0
=
_get_block_table_and_slot_mapping
(
0
)
seq_lens
=
seq_lens
,
cm_base
=
CommonAttentionMetadata
(
_seq_lens_cpu
=
seq_lens_cpu
,
query_start_loc
=
self
.
query_start_loc
.
gpu
[:
num_reqs_padded
+
1
],
_num_computed_tokens_cpu
=
num_computed_tokens_cpu
,
query_start_loc_cpu
=
self
.
query_start_loc
.
cpu
[:
num_reqs_padded
+
1
],
num_actual_tokens
=
num_tokens_padded
,
seq_lens
=
self
.
seq_lens
.
gpu
[:
num_reqs_padded
],
num_reqs
=
num_reqs_padded
,
_seq_lens_cpu
=
self
.
seq_lens
.
cpu
[:
num_reqs_padded
],
max_query_len
=
max_query_len
,
_num_computed_tokens_cpu
=
self
.
input_batch
.
num_computed_tokens_cpu_tensor
[
max_seq_len
=
max_seq_len
,
:
num_reqs_padded
block_table_tensor
=
blk_table_tensor
,
],
slot_mapping
=
slot_mapping
,
num_reqs
=
num_reqs_padded
,
logits_indices_padded
=
logits_indices_padded
,
num_actual_tokens
=
num_tokens_padded
,
num_logits_indices
=
num_logits_indices
,
max_query_len
=
max_query_len
,
causal
=
True
,
max_seq_len
=
max_seq_len
,
encoder_seq_lens
=
encoder_seq_lens
,
block_table_tensor
=
block_table_gid_0
,
encoder_seq_lens_cpu
=
encoder_seq_lens_cpu
,
slot_mapping
=
slot_mapping_gid_0
,
dcp_local_seq_lens
=
dcp_local_seq_lens
,
causal
=
True
,
dcp_local_seq_lens_cpu
=
dcp_local_seq_lens_cpu
,
)
if
self
.
dcp_world_size
>
1
:
self
.
dcp_local_seq_lens
.
cpu
[:
num_reqs
]
=
get_dcp_local_seq_lens
(
self
.
seq_lens
.
cpu
[:
num_reqs
],
self
.
dcp_world_size
,
self
.
dcp_rank
,
self
.
parallel_config
.
cp_kv_cache_interleave_size
,
)
)
self
.
dcp_local_seq_lens
.
cpu
[
num_reqs
:].
fill_
(
0
)
self
.
dcp_local_seq_lens
.
copy_to_gpu
(
num_reqs_padded
)
cm_base
.
dcp_local_seq_lens
=
self
.
dcp_local_seq_lens
.
gpu
[:
num_reqs_padded
]
cm_base
.
dcp_local_seq_lens_cpu
=
self
.
dcp_local_seq_lens
.
cpu
[
:
num_reqs_padded
]
if
logits_indices
is
not
None
and
self
.
cache_config
.
kv_sharing_fast_prefill
:
cm_base
.
num_logits_indices
=
logits_indices
.
size
(
0
)
cm_base
.
logits_indices_padded
=
self
.
_prepare_kv_sharing_fast_prefill
(
logits_indices
)
def
_build_attn_group_metadata
(
kv_cache_gid
:
int
,
attn_gid
:
int
,
common_attn_metadata
:
CommonAttentionMetadata
,
ubid
:
int
|
None
=
None
,
)
->
None
:
attn_group
=
self
.
attn_groups
[
kv_cache_gid
][
attn_gid
]
cascade_attn_prefix_len
=
(
cascade_attn_prefix_lens
[
kv_cache_gid
][
attn_gid
]
if
cascade_attn_prefix_lens
else
0
)
builder
=
attn_group
.
get_metadata_builder
(
ubid
or
0
)
extra_attn_metadata_args
=
{}
if
use_spec_decode
and
isinstance
(
builder
,
GDNAttentionMetadataBuilder
):
assert
ubid
is
None
,
"UBatching not supported with GDN yet"
extra_attn_metadata_args
=
dict
(
num_accepted_tokens
=
self
.
num_accepted_tokens
.
gpu
[:
num_reqs_padded
],
num_decode_draft_tokens_cpu
=
self
.
num_decode_draft_tokens
.
cpu
[
:
num_reqs_padded
],
)
if
for_cudagraph_capture
:
attn_metadata_i
=
builder
.
build_for_cudagraph_capture
(
common_attn_metadata
)
else
:
attn_metadata_i
=
builder
.
build
(
common_prefix_len
=
cascade_attn_prefix_len
,
common_attn_metadata
=
common_attn_metadata
,
**
extra_attn_metadata_args
,
)
if
ubid
is
None
:
assert
isinstance
(
attn_metadata
,
dict
)
attn_metadata_dict
=
attn_metadata
else
:
assert
isinstance
(
attn_metadata
,
list
)
attn_metadata_dict
=
attn_metadata
[
ubid
]
for
layer_name
in
attn_group
.
layer_names
:
attn_metadata_dict
[
layer_name
]
=
attn_metadata_i
# Prepare the attention metadata for each KV cache group and make layers
# in the same group share the same metadata.
spec_decode_common_attn_metadata
=
None
for
kv_cache_gid
,
kv_cache_group
in
enumerate
(
kv_cache_groups
):
cm
=
copy
(
cm_base
)
# shallow copy
# Basically only the encoder seq_lens, block_table and slot_mapping change
# for each kv_cache_group.
cm
.
encoder_seq_lens
,
cm
.
encoder_seq_lens_cpu
=
self
.
_get_encoder_seq_lens
(
num_scheduled_tokens
or
{},
kv_cache_group
.
kv_cache_spec
,
num_reqs_padded
,
)
if
kv_cache_gid
>
0
:
cm
.
block_table_tensor
,
cm
.
slot_mapping
=
(
_get_block_table_and_slot_mapping
(
kv_cache_gid
)
)
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
if
self
.
speculative_config
and
spec_decode_common_attn_metadata
is
None
:
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
isinstance
(
self
.
drafter
,
EagleProposer
):
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
if
self
.
drafter
.
attn_layer_names
[
0
]
in
kv_cache_group
.
layer_names
:
spec_decode_common_attn_metadata
=
c
ommon_attn_metadata
spec_decode_common_attn_metadata
=
c
m
else
:
else
:
spec_decode_common_attn_metadata
=
common_attn_metadata
spec_decode_common_attn_metadata
=
cm
for
attn_gid
,
attn_group
in
enumerate
(
self
.
attn_groups
[
kv_cache_gid
]):
cascade_attn_prefix_len
=
(
cascade_attn_prefix_lens
[
kv_cache_gid
][
attn_gid
]
if
cascade_attn_prefix_lens
else
0
)
builder
=
attn_group
.
get_metadata_builder
()
extra_attn_metadata_args
=
{}
if
use_spec_decode
and
isinstance
(
builder
,
GDNAttentionMetadataBuilder
):
extra_attn_metadata_args
=
dict
(
num_accepted_tokens
=
self
.
num_accepted_tokens
.
gpu
[
:
num_reqs_padded
],
num_decode_draft_tokens_cpu
=
self
.
num_decode_draft_tokens
.
cpu
[
:
num_reqs_padded
],
)
for
attn_gid
in
range
(
len
(
self
.
attn_groups
[
kv_cache_gid
])):
if
ubatch_slices
is
not
None
:
if
ubatch_slices
is
not
None
:
common_attn_metadata_list
=
split_attn_metadata
(
for
ubid
,
_cm
in
enumerate
(
split_attn_metadata
(
ubatch_slices
,
cm
)):
ubatch_slices
,
common_attn_metadata
_build_attn_group_metadata
(
kv_cache_gid
,
attn_gid
,
_cm
,
ubid
)
)
for
ubid
,
common_attn_metadata
in
enumerate
(
common_attn_metadata_list
):
builder
=
attn_group
.
get_metadata_builder
(
ubatch_id
=
ubid
)
if
for_cudagraph_capture
:
attn_metadata_i
=
builder
.
build_for_cudagraph_capture
(
common_attn_metadata
)
else
:
attn_metadata_i
=
builder
.
build
(
common_prefix_len
=
cascade_attn_prefix_len
,
common_attn_metadata
=
common_attn_metadata
,
)
for
layer_name
in
kv_cache_group
.
layer_names
:
assert
type
(
attn_metadata
)
is
list
attn_metadata
[
ubid
][
layer_name
]
=
attn_metadata_i
else
:
else
:
assert
isinstance
(
attn_metadata
,
dict
)
_build_attn_group_metadata
(
kv_cache_gid
,
attn_gid
,
cm
)
if
for_cudagraph_capture
:
attn_metadata_i
=
builder
.
build_for_cudagraph_capture
(
common_attn_metadata
)
else
:
attn_metadata_i
=
builder
.
build
(
common_prefix_len
=
cascade_attn_prefix_len
,
common_attn_metadata
=
common_attn_metadata
,
**
extra_attn_metadata_args
,
)
for
layer_name
in
attn_group
.
layer_names
:
attn_metadata
[
layer_name
]
=
attn_metadata_i
if
self
.
is_mm_prefix_lm
:
if
self
.
is_mm_prefix_lm
:
req_doc_ranges
=
{}
req_doc_ranges
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment